# truth-agent / api.py
# Uploaded by swajall — "Update api.py", commit af83436 (verified)
import os

# Configure the Hugging Face cache location BEFORE importing transformers.
# transformers resolves TRANSFORMERS_CACHE / HF_HOME into its cache-path
# constants at import time, so setting these variables after the import
# (as the original code did) leaves the default cache path in effect.
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"
os.makedirs("/tmp/hf_home", exist_ok=True)

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)
# FastAPI application instance; routes are registered on it below.
app = FastAPI(title="Agent Truth API")

# Hugging Face access token, read from the environment.
# Optional: only needed when the model repositories are private.
HF_TOKEN = os.environ.get("HF_TOKEN")
# ---------------------------
# Load NLI model (sequence classification)
# ---------------------------
# Model repo id is configurable via the NLI_MODEL env var; defaults to the
# author's checkpoint on the Hugging Face Hub.
nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    nli_model_id, token=HF_TOKEN
)
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)
# text-classification pipeline over the NLI model; device=-1 forces CPU.
nli_pipe = pipeline(
    "text-classification",
    model=nli_model,
    tokenizer=nli_tokenizer,
    device=-1,
)
# ---------------------------
# Load Seq2Seq model (T5 family)
# ---------------------------
# Model repo id is configurable via the SEQ2_MODEL env var.
seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, token=HF_TOKEN)
seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, token=HF_TOKEN)
# ---------------------------
# Request Schemas
# ---------------------------
class NLIRequest(BaseModel):
    """Request body for POST /nli: a premise/hypothesis pair to classify."""

    # Statement assumed true.
    premise: str
    # Statement to test against the premise.
    hypothesis: str
class Seq2SeqRequest(BaseModel):
    """Request body for POST /seq2seq: raw transcript text to transform."""

    transcript: str
# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Liveness endpoint: confirms the service is up and reachable."""
    status_message = "Agent Truth API is running 🚀"
    return {"msg": status_message}
@app.post("/nli")
def nli(req: NLIRequest):
    """Classify the premise/hypothesis pair with the NLI pipeline."""
    # The text-classification pipeline takes a pair as {"text", "text_pair"}.
    pair = {"text": req.premise, "text_pair": req.hypothesis}
    prediction = nli_pipe(pair)
    return {"result": prediction}
@app.post("/seq2seq")
def seq2seq(req: Seq2SeqRequest):
    """Generate text from the transcript with the seq2seq model."""
    encoded = tokenizer(
        req.transcript, return_tensors="pt", truncation=True, padding=True
    )
    # Generation is capped at 256 tokens total.
    generated = seq2_model.generate(**encoded, max_length=256)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"truth_json": decoded}