TymaaHammouda committed on
Commit 8442332 · verified · 1 Parent(s): 490fdc6

Update app.py

Files changed (1)
  1. app.py +35 -35
app.py CHANGED
@@ -55,22 +55,22 @@ def load_models():
         deepseek_name,
         dtype=torch.bfloat16,
         device_map="auto"
-    ).to("cuda" if torch.cuda.is_available() else "cpu")
+    )

     # LLaMA (requires HF_TOKEN secret)
-    llama_name = "meta-llama/Llama-3.1-8B-Instruct"
-    hf_token = os.getenv("LLAMA_HF_TOKEN")
-    if hf_token:
-        app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
-        app.state.llama_tokenizer.pad_token = app.state.llama_tokenizer.eos_token
-        app.state.llama_model = AutoModelForCausalLM.from_pretrained(
-            llama_name,
-            token=hf_token,
-            dtype=torch.bfloat16,
-            device_map="auto"
-        )
-    else:
-        print("No HF_TOKEN found, LLaMA will not be available.")
+    # llama_name = "meta-llama/Llama-3.1-8B-Instruct"
+    # hf_token = os.getenv("LLAMA_HF_TOKEN")
+    # if hf_token:
+    #     app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
+    #     app.state.llama_tokenizer.pad_token = app.state.llama_tokenizer.eos_token
+    #     app.state.llama_model = AutoModelForCausalLM.from_pretrained(
+    #         llama_name,
+    #         token=hf_token,
+    #         dtype=torch.bfloat16,
+    #         device_map="auto"
+    #     )
+    # else:
+    #     print("No HF_TOKEN found, LLaMA will not be available.")

     # -----------------------------
     # Model handlers (reuse loaded models)
@@ -104,23 +104,23 @@ def run_deepseek(req1, req2, prompt_type):
     )
     return tokenizer.decode(outputs[0], skip_special_tokens=True)

-def run_llama(req1, req2, prompt_type):
-    tokenizer = app.state.llama_tokenizer
-    model = app.state.llama_model
-    prompt = build_prompt(req1, req2, prompt_type)
-    inputs = tokenizer(
-        [prompt],
-        return_tensors="pt",
-        padding=True,
-        truncation=True
-    ).to(model.device)
-    outputs = model.generate(
-        input_ids=inputs.input_ids,
-        attention_mask=inputs.attention_mask,
-        max_new_tokens=256,
-        pad_token_id=tokenizer.eos_token_id
-    )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+# def run_llama(req1, req2, prompt_type):
+#     tokenizer = app.state.llama_tokenizer
+#     model = app.state.llama_model
+#     prompt = build_prompt(req1, req2, prompt_type)
+#     inputs = tokenizer(
+#         [prompt],
+#         return_tensors="pt",
+#         padding=True,
+#         truncation=True
+#     ).to(model.device)
+#     outputs = model.generate(
+#         input_ids=inputs.input_ids,
+#         attention_mask=inputs.attention_mask,
+#         max_new_tokens=256,
+#         pad_token_id=tokenizer.eos_token_id
+#     )
+#     return tokenizer.decode(outputs[0], skip_special_tokens=True)

 def run_fanar(req1, req2, prompt_type):
     client = OpenAI(base_url="https://api.fanar.qa/v1", api_key=os.getenv("FANAR_API"))
@@ -145,10 +145,10 @@ def predict(request: ConflictDetectionRequest):
     elif request.model_choice == "DeepSeek-Reasoner":
         answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)

-    elif request.model_choice == "LLaMA-3.1-8B-Instruct":
-        if not hasattr(app.state, "llama_model"):
-            return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)
-        answer = run_llama(request.Req1, request.Req2, request.prompt_type)
+    # elif request.model_choice == "LLaMA-3.1-8B-Instruct":
+    #     if not hasattr(app.state, "llama_model"):
+    #         return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)
+    #     answer = run_llama(request.Req1, request.Req2, request.prompt_type)

     elif request.model_choice == "Fanar":
         answer = run_fanar(request.Req1, request.Req2, request.prompt_type)
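
For context, a minimal sketch of the loading pattern this commit settles on, assuming a recent transformers release that accepts the dtype= keyword (older releases spell it torch_dtype). The model name below is a placeholder; app.py defines its own deepseek_name earlier in the file, outside this diff. With device_map="auto", accelerate places the weights during from_pretrained, so the trailing .to("cuda"/"cpu") that this commit removes is redundant and can conflict with that placement.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"  # placeholder, not the app's actual deepseek_name

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,  # same kwarg the diff uses
    device_map="auto",     # accelerate dispatches the weights itself,
)                          # so no trailing .to(...) is needed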