Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,10 @@ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
|
| 7 |
os.environ["WANDB_DISABLED"] = "true"
|
| 8 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
print("CUDA available:", torch.cuda.is_available())
|
| 11 |
if torch.cuda.is_available():
|
| 12 |
print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
|
|
@@ -89,7 +93,8 @@ def _nllb_translate(text: str, src_code: str, tgt_code: str, max_new: int = 256)
|
|
| 89 |
forced_bos_token_id=tgt_id,
|
| 90 |
max_new_tokens=max_new,
|
| 91 |
do_sample=False,
|
| 92 |
-
num_beams=4,
|
|
|
|
| 93 |
repetition_penalty=1.02,
|
| 94 |
)
|
| 95 |
return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip()
|
|
@@ -170,7 +175,7 @@ def _restore_terms(s: str, placeholders: dict):
|
|
| 170 |
return s
|
| 171 |
|
| 172 |
@torch.inference_mode()
|
| 173 |
-
def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=200) -> list[str]:
|
| 174 |
crop = f" Crop: {crop_hint}." if crop_hint else ""
|
| 175 |
prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
|
| 176 |
enc = ag_tok(prompt, return_tensors="pt")
|
|
@@ -229,7 +234,8 @@ def answer_router(question_so: str) -> str:
|
|
| 229 |
en_q_raw = so_to_en(question_so)
|
| 230 |
en_q = polish_en_question(en_q_raw, crop)
|
| 231 |
# Ask the agri model
|
| 232 |
-
bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=200)
|
|
|
|
| 233 |
if not bullets_en:
|
| 234 |
return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
|
| 235 |
# Back to Somali
|
|
|
|
| 7 |
os.environ["WANDB_DISABLED"] = "true"
|
| 8 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 9 |
|
| 10 |
+
torch.set_num_threads(2) # or 1 on the smallest CPU tier
|
| 11 |
+
os.environ["OMP_NUM_THREADS"] = "2"
|
| 12 |
+
os.environ["MKL_NUM_THREADS"] = "2"
|
| 13 |
+
|
| 14 |
print("CUDA available:", torch.cuda.is_available())
|
| 15 |
if torch.cuda.is_available():
|
| 16 |
print("GPU:", torch.cuda.get_device_name(0), "| CC:", torch.cuda.get_device_capability(0))
|
|
|
|
| 93 |
forced_bos_token_id=tgt_id,
|
| 94 |
max_new_tokens=max_new,
|
| 95 |
do_sample=False,
|
| 96 |
+
# num_beams=4,
|
| 97 |
+
num_beams=1,
|
| 98 |
repetition_penalty=1.02,
|
| 99 |
)
|
| 100 |
return mt_tok.batch_decode(out, skip_special_tokens=True)[0].strip()
|
|
|
|
| 175 |
return s
|
| 176 |
|
| 177 |
@torch.inference_mode()
|
| 178 |
+
def agri_answer_bullets_en(q_en: str, so_original: str, crop_hint: str = "", max_new=80) -> list[str]:
|
| 179 |
crop = f" Crop: {crop_hint}." if crop_hint else ""
|
| 180 |
prompt = f"<context> Somali original: {so_original}.{crop} {AGRI_HINT} <user> {q_en} <assistant>"
|
| 181 |
enc = ag_tok(prompt, return_tensors="pt")
|
|
|
|
| 234 |
en_q_raw = so_to_en(question_so)
|
| 235 |
en_q = polish_en_question(en_q_raw, crop)
|
| 236 |
# Ask the agri model
|
| 237 |
+
# bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=200)
|
| 238 |
+
bullets_en = agri_answer_bullets_en(en_q, question_so, crop_hint=crop, max_new=80)
|
| 239 |
if not bullets_en:
|
| 240 |
return "Ma hubo. Fadlan ku celi su’aasha si kooban (dalagga, dhibaatada, iyo meesha)."
|
| 241 |
# Back to Somali
|