nurfarah57 committed on
Commit
8584a8b
·
verified ·
1 Parent(s): 4690a6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -7,6 +7,10 @@ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
7
  os.environ["WANDB_DISABLED"] = "true"
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
 
 
 
 
 
10
  print("CUDA available:", torch.cuda.is_available())
11
  if torch.cuda.is_available():
12
  print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
@@ -89,7 +93,8 @@ def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256)
89
  forced_bos_token_id=tgt_id,
90
  max_new_tokens=max_new,
91
  do_sample=False,
92
- num_beams=4,
 
93
  repetition_penalty=1.02,
94
  )
95
  return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip()
@@ -170,7 +175,7 @@ def _restore_terms(s: str, placeholders: dict):
170
  return s
171
 
172
  @torch.inference_mode()
173
- def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=200) -> list[str]:
174
  crop = f" Crop: {crop_hint}." if crop_hint else ""
175
  prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
176
  enc = ag_tok(prompt, return_tensors="pt")
@@ -229,7 +234,8 @@ def answer_router(question_so: str) -> str:
229
  en_q_raw = so_to_en(question_so)
230
  en_q = polish_en_question(en_q_raw, crop)
231
  # Ask the agri model
232
- bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=200)
 
233
  if not bullets_en:
234
  return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
235
  # Back to Somali
 
7
  os.environ["WANDB_DISABLED"] = "true"
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
 
10
+ torch.set_num_threads(2) # or 1 on the smallest CPU tier
11
+ os.environ["OMP_NUM_THREADS"] = "2"
12
+ os.environ["MKL_NUM_THREADS"] = "2"
13
+
14
  print("CUDA available:", torch.cuda.is_available())
15
  if torch.cuda.is_available():
16
  print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
 
93
  forced_bos_token_id=tgt_id,
94
  max_new_tokens=max_new,
95
  do_sample=False,
96
+ # num_beams=4,
97
+ num_beams=1,
98
  repetition_penalty=1.02,
99
  )
100
  return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip()
 
175
  return s
176
 
177
  @torch.inference_mode()
178
+ def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]:
179
  crop = f" Crop: {crop_hint}." if crop_hint else ""
180
  prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
181
  enc = ag_tok(prompt, return_tensors="pt")
 
234
  en_q_raw = so_to_en(question_so)
235
  en_q = polish_en_question(en_q_raw, crop)
236
  # Ask the agri model
237
+ # bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=200)
238
+ bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=80)
239
  if not bullets_en:
240
  return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
241
  # Back to Somali