# Agent Truth API — FastAPI service exposing an NLI classification endpoint
# and a seq2seq generation endpoint backed by Hugging Face models.
import os

# Configure the Hugging Face cache path BEFORE importing transformers:
# the library reads these environment variables at import time, so setting
# them after the import (as the original code did) has no effect.
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"  # legacy alias for older transformers versions
os.makedirs("/tmp/hf_home", exist_ok=True)

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)

app = FastAPI(title="Agent Truth API")

# Hugging Face token (optional; required only if the models are private)
HF_TOKEN = os.environ.get("HF_TOKEN")
# ---------------------------
# Load NLI model (sequence classification)
# ---------------------------
# Model id is overridable via the NLI_MODEL environment variable.
nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_id, token=HF_TOKEN)
# device=-1 forces CPU inference.
nli_pipe = pipeline("text-classification", model=nli_model, tokenizer=nli_tokenizer, device=-1)
# ---------------------------
# Load Seq2Seq model (T5 family)
# ---------------------------
# Model id is overridable via the SEQ2_MODEL environment variable.
seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
# NOTE(review): this is the module-wide `tokenizer` consumed by the /seq2seq
# route; consider renaming to `seq2_tokenizer` for symmetry with the NLI
# globals above (would require updating the route in the same change).
tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, token=HF_TOKEN)
seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, token=HF_TOKEN)
| # --------------------------- | |
| # Request Schemas | |
| # --------------------------- | |
class NLIRequest(BaseModel):
    """Request body for the NLI endpoint."""
    premise: str     # statement taken as given
    hypothesis: str  # statement to evaluate against the premise
class Seq2SeqRequest(BaseModel):
    """Request body for the seq2seq endpoint."""
    transcript: str  # raw text fed to the seq2seq model
| # --------------------------- | |
| # Routes | |
| # --------------------------- | |
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up.

    The original function was never registered with FastAPI (no route
    decorator), so the endpoint was unreachable; `@app.get("/")` fixes that.
    """
    return {"msg": "Agent Truth API is running 🚀"}
@app.post("/nli")
def nli(req: NLIRequest):
    """Run NLI classification on a premise/hypothesis pair.

    Registered as a POST route — the original function carried no decorator
    and was therefore never exposed by the app.

    Returns:
        {"result": <pipeline output>} — label/score classification result.
    """
    # text/text_pair is the pipeline's expected dict format for paired inputs.
    res = nli_pipe({"text": req.premise, "text_pair": req.hypothesis})
    return {"result": res}
@app.post("/seq2seq")
def seq2seq(req: Seq2SeqRequest):
    """Generate text from a transcript using the seq2seq (T5-family) model.

    Registered as a POST route — the original function carried no decorator
    and was therefore never exposed by the app.

    Returns:
        {"truth_json": <decoded generation>} — the model's decoded output
        string (presumably JSON, judging by the key name — TODO confirm).
    """
    inputs = tokenizer(req.transcript, return_tensors="pt", truncation=True, padding=True)
    outputs = seq2_model.generate(**inputs, max_length=256)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"truth_json": text}