TymaaHammouda committed
Commit 08eeadf · verified · 1 Parent(s): f36f407

Update app.py

Files changed (1): app.py (+28 -9)
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from openai import OpenAI
 
-print("Version ---- 3")
+print("Version ---- 4")
 app = FastAPI()
 
 # -----------------------------
@@ -50,6 +50,7 @@ def load_models():
     # DeepSeek
     deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
     app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
+    app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
     app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
         deepseek_name,
         dtype=torch.bfloat16,
@@ -61,6 +62,7 @@ def load_models():
     hf_token = os.getenv("LLAMA_HF_TOKEN")
     if hf_token:
         app.state.llama_tokenizer = AutoTokenizer.from_pretrained(llama_name, token=hf_token)
+        app.state.llama_tokenizer.pad_token = app.state.llama_tokenizer.eos_token
         app.state.llama_model = AutoModelForCausalLM.from_pretrained(
             llama_name,
             token=hf_token,
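
The two `pad_token = app.state.*_tokenizer.eos_token` assignments above address the fact that Llama-family tokenizers ship without a padding token, so any call that pads a batch would otherwise fail. A minimal sketch of the pattern, reusing the DeepSeek checkpoint name from the hunk above (the `is None` guard is added here for illustration; the commit assigns unconditionally):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
# Llama-derived tokenizers define eos_token but no pad_token; reusing EOS
# lets tokenizer(..., padding=True) build batches without resizing the vocab.
if tokenizer.pad_token is None:  # illustrative guard, not in the commit
    tokenizer.pad_token = tokenizer.eos_token
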
@@ -85,22 +87,39 @@ def run_gpt4(req1, req2, prompt_type, api_key):
     return completion.choices[0].message.content.strip()
 
 def run_deepseek(req1, req2, prompt_type):
-    print("Start run deepseek")
     tokenizer = app.state.deepseek_tokenizer
     model = app.state.deepseek_model
-    print("Start prompt building")
     prompt = build_prompt(req1, req2, prompt_type)
-    print("The prompt is ", prompt)
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-    outputs = model.generate(inputs.input_ids, max_new_tokens=256)
+    inputs = tokenizer(
+        [prompt],
+        return_tensors="pt",
+        padding=True,
+        truncation=True
+    ).to(model.device)
+    outputs = model.generate(
+        input_ids=inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=256,
+        pad_token_id=tokenizer.eos_token_id
+    )
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 def run_llama(req1, req2, prompt_type):
     tokenizer = app.state.llama_tokenizer
     model = app.state.llama_model
     prompt = build_prompt(req1, req2, prompt_type)
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-    outputs = model.generate(inputs.input_ids, max_new_tokens=256)
+    inputs = tokenizer(
+        [prompt],
+        return_tensors="pt",
+        padding=True,
+        truncation=True
+    ).to(model.device)
+    outputs = model.generate(
+        input_ids=inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=256,
+        pad_token_id=tokenizer.eos_token_id
+    )
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 def run_fanar(req1, req2, prompt_type):
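
After this hunk, `run_deepseek` and `run_llama` share the same call pattern: tokenize with `padding=True`/`truncation=True`, then pass `generate()` the attention mask and an explicit `pad_token_id`, which silences the transformers warning about an unset pad token and keeps padding positions out of the attention computation once pad is aliased to EOS. A condensed sketch of that shared pattern; `generate_reply` is a hypothetical helper, not part of the commit:

def generate_reply(model, tokenizer, prompt, max_new_tokens=256):
    # Keep the attention mask so generate() can tell real tokens from
    # padding, which matters once pad_token is aliased to eos_token.
    inputs = tokenizer(
        [prompt], return_tensors="pt", padding=True, truncation=True
    ).to(model.device)
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
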
@@ -125,7 +144,7 @@ def predict(request: ConflictDetectionRequest):
 
     elif request.model_choice == "DeepSeek-Reasoner":
         answer = run_deepseek(request.Req1, request.Req2, request.prompt_type)
-        print("Deepseek answer is : ", answer)
+
     elif request.model_choice == "LLaMA-3.1-8B-Instruct":
         if not hasattr(app.state, "llama_model"):
             return JSONResponse({"error": "LLaMA not loaded (missing HF_TOKEN)"}, status_code=400)