TymaaHammouda committed on
Commit
490fdc6
·
verified ·
1 Parent(s): 8c3bcb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -53
app.py CHANGED
@@ -6,74 +6,121 @@ import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from openai import OpenAI
8
 
 
9
  app = FastAPI()
10
 
11
- # Globals for models
12
- deepseek_model = None
13
- deepseek_tokenizer = None
14
- llama_model = None
15
- llama_tokenizer = None
16
-
17
  class ConflictDetectionRequest(BaseModel):
18
  Req1: str
19
  Req2: str
20
  model_choice: str # "GPT-4", "DeepSeek-Reasoner", "LLaMA-3.1-8B-Instruct", "Fanar"
21
  prompt_type: str # "zero-shot" or "few-shot"
22
- api_key: str = None
23
 
 
 
 
24
  def build_prompt(req1, req2, prompt_type="zero-shot"):
25
  if prompt_type == "zero-shot":
26
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
27
- else:
28
  examples = (
29
- "Example 1:\nReq1: The system shall allow password reset.\nReq2: The system shall not allow password reset.\nAnswer: yes\n\n"
30
- "Example 2:\nReq1: The system shall support Arabic language.\nReq2: The system shall support English language.\nAnswer: no\n\n"
 
 
 
 
 
 
31
  )
32
  return examples + f"Now answer: Do the following sentences contradict each other? 1.{req1} 2.{req2}"
 
 
33
 
34
- def run_deepseek(req1, req2, prompt_type):
35
- global deepseek_model, deepseek_tokenizer
36
- if deepseek_model is None:
37
- print("Loading DeepSeek model into memory...")
38
- model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
39
- deepseek_tokenizer = AutoTokenizer.from_pretrained(model_name)
40
- deepseek_tokenizer.pad_token = deepseek_tokenizer.eos_token
41
- deepseek_model = AutoModelForCausalLM.from_pretrained(
42
- model_name,
43
- torch_dtype=torch.float32 # CPU only
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  )
 
 
 
 
 
 
 
 
45
  prompt = build_prompt(req1, req2, prompt_type)
46
- inputs = deepseek_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
47
- outputs = deepseek_model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  input_ids=inputs.input_ids,
49
  attention_mask=inputs.attention_mask,
50
  max_new_tokens=256,
51
- pad_token_id=deepseek_tokenizer.eos_token_id
52
  )
53
- return deepseek_tokenizer.decode(outputs[0], skip_special_tokens=True)
54
 
55
  def run_llama(req1, req2, prompt_type):
56
- global llama_model, llama_tokenizer
57
- if llama_model is None:
58
- print("Loading LLaMA model into memory...")
59
- model_name = "meta-llama/Llama-3.1-8B-Instruct"
60
- hf_token = os.getenv("LLAMA_HF_TOKEN")
61
- llama_tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
62
- llama_tokenizer.pad_token = llama_tokenizer.eos_token
63
- llama_model = AutoModelForCausalLM.from_pretrained(
64
- model_name,
65
- token=hf_token,
66
- torch_dtype=torch.float32 # CPU only
67
- )
68
  prompt = build_prompt(req1, req2, prompt_type)
69
- inputs = llama_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
70
- outputs = llama_model.generate(
 
 
 
 
 
71
  input_ids=inputs.input_ids,
72
  attention_mask=inputs.attention_mask,
73
  max_new_tokens=256,
74
- pad_token_id=llama_tokenizer.eos_token_id
75
  )
76
- return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
77
 
78
  def run_fanar(req1, req2, prompt_type):
79
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
@@ -84,30 +131,32 @@ def run_fanar(req1, req2, prompt_type):
84
  )
85
  return response.choices[0].message.content.strip()
86
 
 
 
 
87
  @app.post("/predict")
88
  def predict(request: ConflictDetectionRequest):
89
  try:
90
- if request.model_choice == "DeepSeek-Reasoner":
 
 
 
 
 
91
  answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
 
92
  elif request.model_choice == "LLaMA-3.1-8B-Instruct":
 
 
93
  answer = run_llama(request.Req1, request.Req2, request.prompt_type)
 
94
  elif request.model_choice == "Fanar":
95
  answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
96
- elif request.model_choice == "GPT-4":
97
- if not request.api_key:
98
- return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
99
- client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=request.api_key)
100
- prompt = build_prompt(request.Req1, request.Req2, request.prompt_type)
101
- completion = client.chat.completions.create(
102
- model="openai/gpt-4",
103
- messages=[{"role": "user", "content": prompt}],
104
- temperature=0.7,
105
- max_tokens=512
106
- )
107
- answer = completion.choices[0].message.content.strip()
108
  else:
109
  return JSONResponse({"error": "Invalid model_choice"}, status_code=400)
110
 
111
  return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)
 
112
  except Exception as e:
113
  return JSONResponse({"error": str(e)}, status_code=500)
 
from typing import Optional

from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

9
+ print("Version ---- 4")
10
  app = FastAPI()
11
 
12
# -----------------------------
# Request schema
# -----------------------------
class ConflictDetectionRequest(BaseModel):
    """Payload for /predict: two requirement sentences plus model/prompt selection."""

    Req1: str
    Req2: str
    model_choice: str  # "GPT-4", "DeepSeek-Reasoner", "LLaMA-3.1-8B-Instruct", "Fanar"
    prompt_type: str   # "zero-shot" or "few-shot"
    # Optional[...] instead of a bare `str` annotation with a None default:
    # the field is genuinely optional (only GPT-4 needs it) and the old
    # annotation contradicted its own default value.
    api_key: Optional[str] = None  # required only if model_choice == "GPT-4"
 
22
# -----------------------------
# Prompt builder
# -----------------------------
def build_prompt(req1, req2, prompt_type="zero-shot"):
    """Return the LLM prompt asking whether *req1* and *req2* contradict.

    "few-shot" prepends two worked examples; any other value of
    prompt_type (including "zero-shot") uses the plain zero-shot wording.
    """
    if prompt_type == "few-shot":
        # Two worked examples, then the actual question.
        shots = (
            "Example 1:\n"
            "Req1: The system shall allow password reset.\n"
            "Req2: The system shall not allow password reset.\n"
            "Answer: yes\n\n"
            "Example 2:\n"
            "Req1: The system shall support Arabic language.\n"
            "Req2: The system shall support English language.\n"
            "Answer: no\n\n"
        )
        question = f"Now answer: Do the following sentences contradict each other? 1.{req1} 2.{req2}"
        return shots + question
    # zero-shot and any unrecognized prompt_type share the same wording
    return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
 
43
# -----------------------------
# Startup: load models once
# -----------------------------
@app.on_event("startup")  # NOTE(review): on_event is deprecated in newer FastAPI; consider migrating to a lifespan handler
def load_models():
    """Load the DeepSeek (and, if a token is present, LLaMA) checkpoints once.

    Models and tokenizers are stashed on ``app.state`` so request handlers
    reuse them instead of reloading per request.
    """
    print("Loading models into memory...")

    # DeepSeek -- always loaded.
    deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
    # Model ships without a pad token; reuse EOS so padded batches generate cleanly.
    app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
    # BUGFIX: do NOT chain .to("cuda"/"cpu") after device_map="auto" --
    # accelerate has already dispatched the weights, and moving a dispatched
    # model with .to() raises a RuntimeError. device_map="auto" alone picks
    # the best available device(s).
    app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
        deepseek_name,
        torch_dtype=torch.bfloat16,  # torch_dtype is the long-supported kwarg name
        device_map="auto",
    )

    # LLaMA -- gated repo; requires the LLAMA_HF_TOKEN secret.
    llama_name = "meta-llama/Llama-3.1-8B-Instruct"
    hf_token = os.getenv("LLAMA_HF_TOKEN")
    if hf_token:
        app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
        app.state.llama_tokenizer.pad_token = app.state.llama_tokenizer.eos_token
        app.state.llama_model = AutoModelForCausalLM.from_pretrained(
            llama_name,
            token=hf_token,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    else:
        # Fixed message: the env var actually read is LLAMA_HF_TOKEN, not HF_TOKEN.
        print("No LLAMA_HF_TOKEN found, LLaMA will not be available.")
75
# -----------------------------
# Model handlers (reuse loaded models)
# -----------------------------
def run_gpt4(req1, req2, prompt_type, api_key):
    """Ask GPT-4 (via the OpenRouter API) whether the two requirements contradict."""
    prompt = build_prompt(req1, req2, prompt_type)
    router = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    reply = router.chat.completions.create(
        model="openai/gpt-4",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=512,
    )
    return reply.choices[0].message.content.strip()
88
+
89
def run_deepseek(req1, req2, prompt_type):
    """Run the startup-loaded DeepSeek model on the contradiction prompt."""
    tok = app.state.deepseek_tokenizer
    mdl = app.state.deepseek_model
    encoded = tok(
        [build_prompt(req1, req2, prompt_type)],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(mdl.device)
    generated = mdl.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=256,
        pad_token_id=tok.eos_token_id,
    )
    # NOTE: decoding outputs[0] returns the prompt plus the model's answer.
    return tok.decode(generated[0], skip_special_tokens=True)
106
 
107
def run_llama(req1, req2, prompt_type):
    """Run the startup-loaded LLaMA model on the contradiction prompt."""
    tokenizer = app.state.llama_tokenizer
    model = app.state.llama_model
    prompt = build_prompt(req1, req2, prompt_type)
    batch = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    batch = batch.to(model.device)
    output_ids = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
    )
    # NOTE: decoding output_ids[0] returns the prompt plus the model's answer.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
124
 
125
  def run_fanar(req1, req2, prompt_type):
126
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
 
131
  )
132
  return response.choices[0].message.content.strip()
133
 
134
# -----------------------------
# API route
# -----------------------------
@app.post("/predict")
def predict(request: ConflictDetectionRequest):
    """Dispatch the conflict-detection request to the chosen model backend."""
    try:
        # Backends that share the (req1, req2, prompt_type) call shape.
        simple_backends = {
            "DeepSeek-Reasoner": run_deepseek,
            "LLaMA-3.1-8B-Instruct": run_llama,
            "Fanar": run_fanar,
        }
        choice = request.model_choice

        if choice == "GPT-4":
            # GPT-4 is the only backend needing a caller-supplied key.
            if not request.api_key:
                return JSONResponse({"error": "API key required for GPT-4"}, status_code=400)
            answer = run_gpt4(request.Req1, request.Req2, request.prompt_type, request.api_key)
        elif choice in simple_backends:
            # LLaMA is only loaded at startup when the HF token secret exists.
            if choice == "LLaMA-3.1-8B-Instruct" and not hasattr(app.state, "llama_model"):
                return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)
            answer = simple_backends[choice](request.Req1, request.Req2, request.prompt_type)
        else:
            return JSONResponse({"error": "Invalid model_choice"}, status_code=400)

        return JSONResponse({"resp": answer, "statusText": "OK", "statusCode": 0}, status_code=200)

    except Exception as e:
        # Surface the error text to the frontend rather than an opaque 500.
        return JSONResponse({"error": str(e)}, status_code=500)