TymaaHammouda committed on
Commit
983b14c
·
verified ·
1 Parent(s): 2a3cef3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -21
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from openai import OpenAI
8
 
9
- print("Version ---- 2")
10
  app = FastAPI()
11
 
12
  # -----------------------------
@@ -26,7 +26,6 @@ def build_prompt(req1, req2, prompt_type="zero-shot"):
26
  if prompt_type == "zero-shot":
27
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
28
  elif prompt_type == "few-shot":
29
- # Example few-shot style (you can expand with more examples)
30
  examples = (
31
  "Example 1:\n"
32
  "Req1: The system shall allow password reset.\n"
@@ -42,7 +41,37 @@ def build_prompt(req1, req2, prompt_type="zero-shot"):
42
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
43
 
44
  # -----------------------------
45
- # Model handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # -----------------------------
47
  def run_gpt4(req1, req2, prompt_type, api_key):
48
  client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
@@ -56,34 +85,21 @@ def run_gpt4(req1, req2, prompt_type, api_key):
56
  return completion.choices[0].message.content.strip()
57
 
58
  def run_deepseek(req1, req2, prompt_type):
59
- model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
60
- tokenizer = AutoTokenizer.from_pretrained(model_name)
61
- model = AutoModelForCausalLM.from_pretrained(
62
- model_name,
63
- dtype=torch.bfloat16,
64
- device_map="auto"
65
- )
66
  prompt = build_prompt(req1, req2, prompt_type)
67
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
68
  outputs = model.generate(inputs.input_ids, max_new_tokens=256)
69
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
70
 
71
  def run_llama(req1, req2, prompt_type):
72
- model_name = "meta-llama/Llama-3.1-8B-Instruct"
73
- hf_token = os.getenv("LLAMA_HF_TOKEN")
74
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
75
- model = AutoModelForCausalLM.from_pretrained(
76
- model_name,
77
- token=hf_token,
78
- dtype=torch.bfloat16,
79
- device_map="auto"
80
- )
81
  prompt = build_prompt(req1, req2, prompt_type)
82
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
83
  outputs = model.generate(inputs.input_ids, max_new_tokens=256)
84
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
85
 
86
-
87
  def run_fanar(req1, req2, prompt_type):
88
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
89
  prompt = build_prompt(req1, req2, prompt_type)
@@ -91,7 +107,6 @@ def run_fanar(req1, req2, prompt_type):
91
  model="Fanar",
92
  messages=[{"role": "user", "content": prompt}]
93
  )
94
- print("fanar response: ", response)
95
  return response.choices[0].message.content.strip()
96
 
97
  # -----------------------------
@@ -109,6 +124,8 @@ def predict(request: ConflictDetectionRequest):
109
  answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
110
 
111
  elif request.model_choice == "LLaMA-3.1-8B-Instruct":
 
 
112
  answer = run_llama(request.Req1, request.Req2, request.prompt_type)
113
 
114
  elif request.model_choice == "Fanar":
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from openai import OpenAI
8
 
9
+ print("Version ---- 3")
10
  app = FastAPI()
11
 
12
  # -----------------------------
 
26
  if prompt_type == "zero-shot":
27
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
28
  elif prompt_type == "few-shot":
 
29
  examples = (
30
  "Example 1:\n"
31
  "Req1: The system shall allow password reset.\n"
 
41
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
42
 
43
  # -----------------------------
44
+ # Startup: load models once
45
+ # -----------------------------
46
@app.on_event("startup")
def load_models():
    """Load the local LLMs once at process startup and cache them on ``app.state``.

    DeepSeek is loaded unconditionally. LLaMA-3.1 is a gated repository, so it
    is only loaded when an ``HF_TOKEN`` environment variable is configured;
    otherwise the app starts without it and the LLaMA handler stays unavailable.
    """
    print("Loading models into memory...")

    # DeepSeek-R1 distilled checkpoint — publicly downloadable, no token needed.
    deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
    app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
        deepseek_name,
        dtype=torch.bfloat16,
        device_map="auto",
    )

    # LLaMA (gated): bail out early when no HF_TOKEN secret is present.
    llama_name = "meta-llama/Llama-3.1-8B-Instruct"
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        print("No HF_TOKEN found, LLaMA will not be available.")
        return

    app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
    app.state.llama_model = AutoModelForCausalLM.from_pretrained(
        llama_name,
        token=hf_token,
        dtype=torch.bfloat16,
        device_map="auto",
    )
72
+
73
+ # -----------------------------
74
+ # Model handlers (reuse loaded models)
75
  # -----------------------------
76
  def run_gpt4(req1, req2, prompt_type, api_key):
77
  client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
 
85
  return completion.choices[0].message.content.strip()
86
 
87
  def run_deepseek(req1, req2, prompt_type):
88
+ tokenizer = app.state.deepseek_tokenizer
89
+ model = app.state.deepseek_model
 
 
 
 
 
90
  prompt = build_prompt(req1, req2, prompt_type)
91
  inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
92
  outputs = model.generate(inputs.input_ids, max_new_tokens=256)
93
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
94
 
95
def run_llama(req1, req2, prompt_type):
    """Ask the cached LLaMA model whether the two requirements contradict.

    Assumes ``load_models`` put ``llama_tokenizer``/``llama_model`` on
    ``app.state`` (i.e. HF_TOKEN was configured); callers guard for that.
    Returns the raw decoded generation.
    """
    tok = app.state.llama_tokenizer
    mdl = app.state.llama_model

    text = build_prompt(req1, req2, prompt_type)
    encoded = tok([text], return_tensors="pt").to(mdl.device)
    generated = mdl.generate(encoded.input_ids, max_new_tokens=256)
    return tok.decode(generated[0], skip_special_tokens=True)
102
 
 
103
  def run_fanar(req1, req2, prompt_type):
104
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
105
  prompt = build_prompt(req1, req2, prompt_type)
 
107
  model="Fanar",
108
  messages=[{"role": "user", "content": prompt}]
109
  )
 
110
  return response.choices[0].message.content.strip()
111
 
112
  # -----------------------------
 
124
  answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
125
 
126
  elif request.model_choice == "LLaMA-3.1-8B-Instruct":
127
+ if not hasattr(app.state, "llama_model"):
128
+ return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)
129
  answer = run_llama(request.Req1, request.Req2, request.prompt_type)
130
 
131
  elif request.model_choice == "Fanar":