File size: 2,122 Bytes
255dff8
 
 
af83436
 
 
 
 
 
6c316ff
af83436
4a81ef1
44547d1
4a81ef1
255dff8
 
 
 
 
 
af83436
 
 
255dff8
af83436
 
 
6c316ff
af83436
255dff8
 
af83436
 
255dff8
 
 
af83436
 
 
255dff8
af83436
 
255dff8
af83436
 
 
255dff8
 
 
 
 
 
 
af83436
 
 
255dff8
 
af83436
255dff8
 
 
af83436
 
255dff8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os

# NOTE: the Hugging Face cache variables MUST be set before `transformers`
# (and transitively `huggingface_hub`) is imported — both libraries read
# HF_HOME / TRANSFORMERS_CACHE at import time, so assigning them afterwards
# silently has no effect and models get cached in the default location.
os.environ["HF_HOME"] = "/tmp/hf_home"
# TRANSFORMERS_CACHE is deprecated in transformers >= 4.36 in favor of
# HF_HOME; kept for compatibility with older installed versions.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"
os.makedirs("/tmp/hf_home", exist_ok=True)

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
)

app = FastAPI(title="Agent Truth API")

# Hugging Face token (optional if models are private)
# Read once at startup; passed to every from_pretrained call below.
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------------------------
# Load NLI model (sequence classification)
# ---------------------------
# Model id is overridable via env so deployments can swap checkpoints
# without a code change. Weights are downloaded at import time.
nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    nli_model_id, token=HF_TOKEN
)
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)

# device=-1 forces CPU inference (no GPU assumed in the deployment target).
nli_pipe = pipeline(
    "text-classification",
    model=nli_model,
    tokenizer=nli_tokenizer,
    device=-1,
)

# ---------------------------
# Load Seq2Seq model (T5 family)
# ---------------------------
# Used by the /seq2seq route; `tokenizer` and `seq2_model` are module-level
# globals referenced there.
seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, token=HF_TOKEN)
seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, token=HF_TOKEN)

# ---------------------------
# Request Schemas
# ---------------------------
class NLIRequest(BaseModel):
    """Request body for POST /nli: a premise/hypothesis pair to classify."""
    premise: str  # the statement treated as given context
    hypothesis: str  # the statement evaluated against the premise

class Seq2SeqRequest(BaseModel):
    """Request body for POST /seq2seq: raw transcript text to transform."""
    transcript: str  # free-form input text fed to the seq2seq model

# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def root():
    """Liveness endpoint: confirms the API process is up."""
    status_message = "Agent Truth API is running 🚀"
    return {"msg": status_message}

@app.post("/nli")
def nli(req: NLIRequest):
    """Run the NLI classifier on a premise/hypothesis pair.

    Returns the pipeline's label/score output under the "result" key.
    """
    # text-classification pipelines take paired input as a
    # {"text": ..., "text_pair": ...} mapping.
    pair = {"text": req.premise, "text_pair": req.hypothesis}
    return {"result": nli_pipe(pair)}

@app.post("/seq2seq")
def seq2seq(req: Seq2SeqRequest):
    """Generate text from a transcript with the T5-family model.

    Tokenizes the transcript (truncated to the model's max input length),
    greedily generates up to 256 tokens, and returns the decoded string
    under the "truth_json" key.
    """
    # Local import: torch is a required backend of the transformers model
    # classes already imported at module level, so it is always available.
    import torch

    inputs = tokenizer(
        req.transcript, return_tensors="pt", truncation=True, padding=True
    )
    # Inference only — disabling autograd avoids building a gradient graph,
    # cutting memory use and speeding up generate() with identical output.
    with torch.no_grad():
        outputs = seq2_model.generate(**inputs, max_length=256)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"truth_json": text}