"""Agent Truth API.

FastAPI service exposing two Hugging Face model endpoints:

* ``POST /nli``      — sentence-pair NLI classification (premise vs. hypothesis).
* ``POST /seq2seq``  — text generation from a transcript with a T5-family model.

Models are loaded once at import time on CPU. Cache paths are redirected to
``/tmp`` so the service works in read-only / serverless filesystems.
"""

import os

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)

# Ensure Hugging Face cache path lives under /tmp (the only writable location
# in many container / serverless environments).
os.environ["HF_HOME"] = "/tmp/hf_home"
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases (HF_HOME supersedes it); kept for compatibility with older versions.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"
os.makedirs("/tmp/hf_home", exist_ok=True)

app = FastAPI(title="Agent Truth API")

# Hugging Face token (optional if models are public; needed for private repos).
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------------------------
# Load NLI model (sequence classification)
# ---------------------------
nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    nli_model_id, token=HF_TOKEN
)
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)

nli_pipe = pipeline(
    "text-classification",
    model=nli_model,
    tokenizer=nli_tokenizer,
    device=-1,  # -1 = CPU
)

# ---------------------------
# Load Seq2Seq model (T5 family)
# ---------------------------
seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, token=HF_TOKEN)
seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, token=HF_TOKEN)


# ---------------------------
# Request Schemas
# ---------------------------
class NLIRequest(BaseModel):
    # Sentence pair for natural-language-inference classification.
    premise: str
    hypothesis: str


class Seq2SeqRequest(BaseModel):
    # Free-form input text fed to the seq2seq model.
    transcript: str


# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Health-check / landing route."""
    return {"msg": "Agent Truth API is running 🚀"}


@app.post("/nli")
def nli(req: NLIRequest):
    """Classify the relation between ``premise`` and ``hypothesis``.

    Returns the raw pipeline output (label + score) under ``"result"``.
    """
    # {"text": ..., "text_pair": ...} is the pipeline's expected input format
    # for sentence-pair (cross-encoder) classification.
    res = nli_pipe({"text": req.premise, "text_pair": req.hypothesis})
    return {"result": res}


@app.post("/seq2seq")
def seq2seq(req: Seq2SeqRequest):
    """Run the seq2seq model on ``transcript`` and return the decoded text.

    The transcript is truncated to the tokenizer's model max length; generation
    is capped at 256 tokens. The decoded string is returned under
    ``"truth_json"`` — presumably the model emits JSON-shaped text, but that
    depends on the fine-tuned model (TODO confirm).
    """
    inputs = tokenizer(
        req.transcript, return_tensors="pt", truncation=True, padding=True
    )
    outputs = seq2_model.generate(**inputs, max_length=256)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"truth_json": text}