TymaaHammouda committed on
Commit
0948bff
·
verified ·
1 Parent(s): 8442332

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -48
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from openai import OpenAI
8
 
9
- print("Version ---- 4")
10
  app = FastAPI()
11
 
12
  # -----------------------------
@@ -15,7 +15,7 @@ app = FastAPI()
15
  class ConflictDetectionRequest(BaseModel):
16
  Req1: str
17
  Req2: str
18
- model_choice: str # "GPT-4", "DeepSeek-Reasoner", "LLaMA-3.1-8B-Instruct", "Fanar"
19
  prompt_type: str # "zero-shot" or "few-shot"
20
  api_key: str = None # required only if model_choice == "GPT-4"
21
 
@@ -41,39 +41,21 @@ def build_prompt(req1, req2, prompt_type="zero-shot"):
41
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
42
 
43
  # -----------------------------
44
- # Startup: load models once
45
  # -----------------------------
46
  @app.on_event("startup")
47
  def load_models():
48
- print("Loading models into memory...")
49
-
50
- # DeepSeek
51
  deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
52
  app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
53
  app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
54
  app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
55
  deepseek_name,
56
- dtype=torch.bfloat16,
57
- device_map="auto"
58
  )
59
 
60
- # LLaMA (requires HF_TOKEN secret)
61
- # llama_name = "meta-llama/Llama-3.1-8B-Instruct"
62
- # hf_token = os.getenv("LLAMA_HF_TOKEN")
63
- # if hf_token:
64
- # app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
65
- # app.state.llama_tokenizer.pad_token = app.state.llama_tokenizer.eos_token
66
- # app.state.llama_model = AutoModelForCausalLM.from_pretrained(
67
- # llama_name,
68
- # token=hf_token,
69
- # dtype=torch.bfloat16,
70
- # device_map="auto"
71
- # )
72
- # else:
73
- # print("No HF_TOKEN found, LLaMA will not be available.")
74
-
75
  # -----------------------------
76
- # Model handlers (reuse loaded models)
77
  # -----------------------------
78
  def run_gpt4(req1, req2, prompt_type, api_key):
79
  client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
@@ -95,7 +77,7 @@ def run_deepseek(req1, req2, prompt_type):
95
  return_tensors="pt",
96
  padding=True,
97
  truncation=True
98
- ).to(model.device)
99
  outputs = model.generate(
100
  input_ids=inputs.input_ids,
101
  attention_mask=inputs.attention_mask,
@@ -104,24 +86,6 @@ def run_deepseek(req1, req2, prompt_type):
104
  )
105
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
106
 
107
- # def run_llama(req1, req2, prompt_type):
108
- # tokenizer = app.state.llama_tokenizer
109
- # model = app.state.llama_model
110
- # prompt = build_prompt(req1, req2, prompt_type)
111
- # inputs = tokenizer(
112
- # [prompt],
113
- # return_tensors="pt",
114
- # padding=True,
115
- # truncation=True
116
- # ).to(model.device)
117
- # outputs = model.generate(
118
- # input_ids=inputs.input_ids,
119
- # attention_mask=inputs.attention_mask,
120
- # max_new_tokens=256,
121
- # pad_token_id=tokenizer.eos_token_id
122
- # )
123
- # return tokenizer.decode(outputs[0], skip_special_tokens=True)
124
-
125
  def run_fanar(req1, req2, prompt_type):
126
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
127
  prompt = build_prompt(req1, req2, prompt_type)
@@ -145,11 +109,6 @@ def predict(request: ConflictDetectionRequest):
145
  elif request.model_choice == "DeepSeek-Reasoner":
146
  answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
147
 
148
- # elif request.model_choice == "LLaMA-3.1-8B-Instruct":
149
- # if not hasattr(app.state, "llama_model"):
150
- # return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)
151
- # answer = run_llama(request.Req1, request.Req2, request.prompt_type)
152
-
153
  elif request.model_choice == "Fanar":
154
  answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
155
 
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from openai import OpenAI
8
 
9
+ print("Version ---- DeepSeek Only")
10
  app = FastAPI()
11
 
12
  # -----------------------------
 
15
  class ConflictDetectionRequest(BaseModel):
16
  Req1: str
17
  Req2: str
18
+ model_choice: str # "GPT-4", "DeepSeek-Reasoner", "Fanar"
19
  prompt_type: str # "zero-shot" or "few-shot"
20
  api_key: str = None # required only if model_choice == "GPT-4"
21
 
 
41
  return f"Do the following sentences contradict each other, answer with just yes or no: 1.{req1} 2.{req2}"
42
 
43
  # -----------------------------
44
+ # Startup: load DeepSeek once
45
  # -----------------------------
46
  @app.on_event("startup")
47
  def load_models():
48
+ print("Loading DeepSeek model into memory...")
 
 
49
  deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
50
  app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
51
  app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
52
  app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
53
  deepseek_name,
54
+ torch_dtype=torch.float32 # CPU only
 
55
  )
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # -----------------------------
58
+ # Model handlers
59
  # -----------------------------
60
  def run_gpt4(req1, req2, prompt_type, api_key):
61
  client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
 
77
  return_tensors="pt",
78
  padding=True,
79
  truncation=True
80
+ )
81
  outputs = model.generate(
82
  input_ids=inputs.input_ids,
83
  attention_mask=inputs.attention_mask,
 
86
  )
87
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def run_fanar(req1, req2, prompt_type):
90
  client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
91
  prompt = build_prompt(req1, req2, prompt_type)
 
109
  elif request.model_choice == "DeepSeek-Reasoner":
110
  answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
111
 
 
 
 
 
 
112
  elif request.model_choice == "Fanar":
113
  answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
114