"""Gradio Space: Somali agriculture suite.

Three tabs backed by three Hugging Face models:
  - NLLB-200 distilled 600M (seq2seq) for translation
  - AgriParam (auto-detected: sequence classifier or causal generator)
  - Somali-Agri LLaMAX3-8B for chat

All models are loaded once at import time; on CUDA with LOAD_4BIT=1
(the default) they are loaded in 4-bit NF4 to fit small-VRAM GPUs.
"""

import os
from typing import Optional

import gradio as gr
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    pipeline,
)

# --- YOUR MODELS ---
HF_TRANSLATOR_MODEL = "facebook/nllb-200-distilled-600M"            # seq2seq
HF_AGRIPARAM_MODEL = "bharatgenai/AgriParam"                        # classifier or causal; we auto-detect
HF_LLAMAX_MODEL = "nurfarah57/Somali-Agri-LLaMAX3-8B-Merged"        # LLaMA-family chat

# --- SETTINGS (override via Space Variables if you like) ---
LOAD_4BIT = os.getenv("LOAD_4BIT", "1") == "1"          # keep 4-bit on for small VRAM
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "1") == "1"


def _bnb_kwargs() -> dict:
    """Return ``from_pretrained`` kwargs for the current hardware.

    On CUDA with LOAD_4BIT: 4-bit NF4 double-quantized weights with bf16
    compute. Otherwise: plain bf16 (GPU) or fp32 (CPU) with no quantization.
    """
    if LOAD_4BIT and torch.cuda.is_available():
        # Imported lazily: bitsandbytes is only needed (and installed) on GPU.
        from transformers import BitsAndBytesConfig

        return dict(
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            ),
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    return dict(
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )


def _arch(cfg) -> str:
    """Lower-cased first architecture string of a model config ('' if absent)."""
    return (cfg.architectures or [""])[0].lower()


def _is_seq2seq(cfg) -> bool:
    """True if the config names a seq2seq / encoder-decoder architecture."""
    arch = _arch(cfg)
    return any(k in arch for k in ("seq2seq", "conditionalgeneration", "mbart", "marian", "t5"))


def _is_causal(cfg) -> bool:
    """True if the config names a decoder-only (causal LM) architecture."""
    arch = _arch(cfg)
    return any(k in arch for k in ("causallm", "llama", "gpt", "mistral"))


def _is_classifier(cfg) -> bool:
    """True if the config names a sequence-classification head."""
    return "sequenceclassification" in _arch(cfg)


def load_any(repo_id: str):
    """Load *repo_id* and return ``(kind, tokenizer, model)``.

    ``kind`` is one of "seq2seq", "classifier", or "causal" (the default
    when the architecture string matches nothing else). Generative models
    get a pad token (eos) so batched generation does not warn/fail.
    """
    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE)
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=TRUST_REMOTE_CODE)
    if _is_seq2seq(cfg):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs()
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        return ("seq2seq", tok, model)
    if _is_classifier(cfg):
        model = AutoModelForSequenceClassification.from_pretrained(
            repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs()
        )
        return ("classifier", tok, model)
    # default to causal
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs()
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return ("causal", tok, model)


def _sampling_kwargs(temperature: float, top_p: float) -> dict:
    """Generation kwargs that tolerate temperature == 0.

    BUG FIX: the UI sliders allow temperature 0.0, but transformers raises
    when ``do_sample=True`` and ``temperature == 0``. Fall back to greedy
    decoding in that case instead of crashing.
    """
    if temperature and temperature > 0:
        return dict(do_sample=True, temperature=temperature, top_p=top_p)
    return dict(do_sample=False)


# ----- Translator (NLLB-200 600M) -----
tr_type, tr_tok, tr_model = load_any(HF_TRANSLATOR_MODEL)


def translate(text: str, src_code: str, tgt_code: str, temperature: float, top_p: float) -> str:
    """Translate *text* between NLLB language codes (e.g. eng_Latn -> som_Latn)."""
    if tr_type != "seq2seq":
        return "Translator must be a seq2seq model."
    # Force the decoder to start in the target language. Older tokenizers
    # expose ``lang_code_to_id``; newer NLLB fast tokenizers removed it, so
    # fall back to looking the code up in the vocabulary directly.
    forced = {}
    if hasattr(tr_tok, "lang_code_to_id") and tgt_code in tr_tok.lang_code_to_id:
        forced["forced_bos_token_id"] = tr_tok.lang_code_to_id[tgt_code]
    else:
        tgt_id = tr_tok.convert_tokens_to_ids(tgt_code)
        if tgt_id is not None and tgt_id != getattr(tr_tok, "unk_token_id", None):
            forced["forced_bos_token_id"] = tgt_id
    tr_tok.src_lang = src_code  # must be set before tokenizing the source text
    inputs = tr_tok(text, return_tensors="pt", padding=True, truncation=True).to(tr_model.device)
    with torch.inference_mode():
        out = tr_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=1,
            length_penalty=1.0,
            **_sampling_kwargs(temperature, top_p),
            **forced,
        )
    return tr_tok.decode(out[0], skip_special_tokens=True)


# ----- AgriParam (auto-detect clf vs causal) -----
ap_type, ap_tok, ap_model = load_any(HF_AGRIPARAM_MODEL)
ap_pipe: Optional[TextClassificationPipeline] = None
if ap_type == "classifier":
    ap_pipe = pipeline(
        "text-classification",
        model=ap_model,
        tokenizer=ap_tok,
        device=0 if torch.cuda.is_available() else -1,
        truncation=True,
    )


def agriparam_infer(text: str, temperature: float, top_p: float) -> str:
    """Classify *text* (all label scores, best first) or generate a continuation,
    depending on the auto-detected model head."""
    if ap_type == "classifier":
        # top_k=None returns every label's score (return_all_scores= is deprecated).
        res = ap_pipe(text, top_k=None)
        if res and isinstance(res[0], list):  # some pipeline versions nest per-input
            res = res[0]
        res = sorted(res, key=lambda d: d["score"], reverse=True)
        return "\n".join(f"{r['label']}: {r['score']:.4f}" for r in res)
    # treat as generator
    inputs = ap_tok(text, return_tensors="pt").to(ap_model.device)
    with torch.inference_mode():
        out = ap_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=ap_tok.eos_token_id,
            **_sampling_kwargs(temperature, top_p),
        )
    return ap_tok.decode(out[0], skip_special_tokens=True)


# ----- LlamaX chat (8B) -----
lx_type, lx_tok, lx_model = load_any(HF_LLAMAX_MODEL)


def _apply_chat_template(
    user_msg: str,
    system_prompt: str = "You are a helpful Somali agriculture assistant.",
) -> str:
    """Build the chat prompt via the tokenizer's template, with a Llama-2 fallback."""
    if hasattr(lx_tok, "apply_chat_template"):
        msgs = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]
        return lx_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # BUG FIX: the fallback previously emitted empty "<>" markers; use the
    # canonical Llama-2 <<SYS>> wrapping inside the [INST] block instead.
    return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_msg} [/INST]"


def llamax_chat(user_msg: str, system_prompt: str, temperature: float, top_p: float) -> str:
    """Generate an assistant reply, returning only the newly generated text."""
    prompt = _apply_chat_template(user_msg, system_prompt)
    inputs = lx_tok(prompt, return_tensors="pt").to(lx_model.device)
    with torch.inference_mode():
        out = lx_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=lx_tok.eos_token_id,
            **_sampling_kwargs(temperature, top_p),
        )
    # BUG FIX: slice off the prompt tokens instead of str.replace(prompt, ""):
    # decoding with skip_special_tokens strips the template's special tokens,
    # so the decoded text never contains the literal prompt and the old code
    # leaked the prompt echo into the reply.
    prompt_len = inputs["input_ids"].shape[1]
    return lx_tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


# ----- Gradio UI -----
with gr.Blocks(title="Somali Agri • LlamaX + AgriParam + NLLB") as demo:
    gr.Markdown(
        "### 🌾 Somali Agri Suite\n- **LlamaX 8B** chat\n- **AgriParam** (classification or generator)\n- **NLLB-200 600M** translator"
    )
    with gr.Tabs():
        with gr.Tab("Translator (NLLB-200)"):
            src = gr.Textbox(label="Source text")
            with gr.Row():
                src_code = gr.Textbox(value="eng_Latn", label="Source language code")
                tgt_code = gr.Textbox(value="som_Latn", label="Target language code")
            t_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            t_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            t_btn = gr.Button("Translate")
            t_out = gr.Textbox(label="Translation", lines=6)
            t_btn.click(translate, [src, src_code, tgt_code, t_temp, t_topp], t_out)
        with gr.Tab("AgriParam"):
            ap_in = gr.Textbox(label="Text / instruction")
            ap_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            ap_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            ap_btn = gr.Button("Run")
            ap_out = gr.Textbox(label="Output", lines=10)
            ap_btn.click(agriparam_infer, [ap_in, ap_temp, ap_topp], ap_out)
        with gr.Tab("LlamaX Chat"):
            sys = gr.Textbox(value="You are a helpful Somali agriculture assistant.", label="System prompt")
            user = gr.Textbox(label="User message")
            lx_temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
            lx_topp = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            lx_btn = gr.Button("Generate")
            lx_out = gr.Textbox(label="Assistant", lines=12)
            lx_btn.click(llamax_chat, [user, sys, lx_temp, lx_topp], lx_out)
    gr.Markdown(
        f"**Loaded**:\n- Translator: `{HF_TRANSLATOR_MODEL}`\n- AgriParam: `{HF_AGRIPARAM_MODEL}`\n- LlamaX: `{HF_LLAMAX_MODEL}`\n- 4-bit quant: `{LOAD_4BIT}`"
    )

if __name__ == "__main__":
    demo.launch()