swajall committed on
Commit
af83436
·
verified ·
1 Parent(s): 6c316ff

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +30 -13
api.py CHANGED
@@ -1,10 +1,14 @@
1
- # api.py yes
2
-
3
  import os
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
 
 
 
 
 
7
 
 
8
  os.environ["HF_HOME"] = "/tmp/hf_home"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"
10
  os.makedirs("/tmp/hf_home", exist_ok=True)
@@ -14,23 +18,32 @@ app = FastAPI(title="Agent Truth API")
14
  # Hugging Face token (optional if models are private)
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
 
17
- # Load NLI model (text-classification)
 
 
18
  nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")
19
- nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_id, token=HF_TOKEN)
 
 
20
  nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)
 
21
  nli_pipe = pipeline(
22
  "text-classification",
23
- model=nli_model_id,
24
- tokenizer=nli_model_id,
25
  device=-1,
26
- use_auth_token=HF_TOKEN
27
  )
28
 
29
- # Load Seq2Seq model (T5)
 
 
30
  seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
31
- tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, use_auth_token=HF_TOKEN)
32
- seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, use_auth_token=HF_TOKEN)
33
 
 
 
 
34
  class NLIRequest(BaseModel):
35
  premise: str
36
  hypothesis: str
@@ -38,13 +51,17 @@ class NLIRequest(BaseModel):
38
  class Seq2SeqRequest(BaseModel):
39
  transcript: str
40
 
 
 
 
41
  @app.get("/")
42
  def root():
43
- return {"msg": "Agent Truth API is running"}
44
 
45
  @app.post("/nli")
46
  def nli(req: NLIRequest):
47
- res = nli_pipe((req.premise, req.hypothesis))
 
48
  return {"result": res}
49
 
50
  @app.post("/seq2seq")
 
 
 
1
  import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
+ from transformers import (
5
+ pipeline,
6
+ AutoTokenizer,
7
+ AutoModelForSeq2SeqLM,
8
+ AutoModelForSequenceClassification,
9
+ )
10
 
11
# Redirect every Hugging Face cache location to a writable tmp directory.
# NOTE(review): transformers is imported above this point and cache env vars
# are typically read at import time — confirm these settings actually take
# effect, or move them ahead of the transformers import.
os.makedirs("/tmp/hf_home", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf_home"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_home"
 
18
# Hugging Face token (optional if models are private)
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------------------------
# NLI model (sequence classification)
# ---------------------------
nli_model_id = os.environ.get("NLI_MODEL", "swajall/nli-model")

# Load model and tokenizer explicitly, then hand the objects (not the ids)
# to the pipeline so the hub is only hit once per artifact.
nli_model = AutoModelForSequenceClassification.from_pretrained(
    nli_model_id,
    token=HF_TOKEN,
)
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_id, token=HF_TOKEN)

nli_pipe = pipeline(
    "text-classification",
    model=nli_model,
    tokenizer=nli_tokenizer,
    device=-1,  # -1 = CPU
)

# ---------------------------
# Seq2Seq model (T5 family)
# ---------------------------
seq2_model_id = os.environ.get("SEQ2_MODEL", "swajall/seq2seq-model")
tokenizer = AutoTokenizer.from_pretrained(seq2_model_id, token=HF_TOKEN)
seq2_model = AutoModelForSeq2SeqLM.from_pretrained(seq2_model_id, token=HF_TOKEN)
43
 
44
# ---------------------------
# Request Schemas
# ---------------------------
class NLIRequest(BaseModel):
    """Request body for POST /nli: one premise/hypothesis sentence pair."""

    premise: str
    hypothesis: str
 
51
class Seq2SeqRequest(BaseModel):
    """Request body for POST /seq2seq: the raw transcript text to process."""

    transcript: str
53
 
54
+ # ---------------------------
55
+ # Routes
56
+ # ---------------------------
57
  @app.get("/")
58
  def root():
59
+ return {"msg": "Agent Truth API is running 🚀"}
60
 
61
  @app.post("/nli")
62
  def nli(req: NLIRequest):
63
+ # Correct input format for text + hypothesis
64
+ res = nli_pipe({"text": req.premise, "text_pair": req.hypothesis})
65
  return {"result": res}
66
 
67
  @app.post("/seq2seq")