# app.py — Hugging Face Space (Gradio). Note: this source was recovered from a
# Spaces page dump whose status banner showed "Runtime error" at capture time.
| import os, torch, gradio as gr | |
| from typing import Optional | |
| from transformers import ( | |
| AutoTokenizer, AutoConfig, | |
| AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModelForSequenceClassification, | |
| TextClassificationPipeline, pipeline | |
| ) | |
# --- YOUR MODELS ---
HF_TRANSLATOR_MODEL = "facebook/nllb-200-distilled-600M"  # seq2seq
HF_AGRIPARAM_MODEL = "bharatgenai/AgriParam"  # classifier or causal; we auto-detect
HF_LLAMAX_MODEL = "nurfarah57/Somali-Agri-LLaMAX3-8B-Merged"  # LLaMA-family chat
# --- SETTINGS (override via Space Variables if you like) ---
# All three read process environment variables so they can be tuned per-Space
# without code changes; "1"/"0" strings toggle the boolean flags.
LOAD_4BIT = os.getenv("LOAD_4BIT", "1") == "1"  # keep 4-bit on for small VRAM
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))  # generation length cap
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "1") == "1"  # allow repo-hosted code
def _bnb_kwargs():
    """Return `from_pretrained` kwargs controlling quantization and placement.

    When LOAD_4BIT is enabled and a CUDA device exists, loads in 4-bit NF4 with
    double quantization; otherwise falls back to bf16 on GPU / fp32 on CPU.

    Returns:
        dict: keyword arguments to splat into `*.from_pretrained(...)`.
    """
    if LOAD_4BIT and torch.cuda.is_available():
        from transformers import BitsAndBytesConfig
        # Pre-Ampere GPUs (e.g. T4 on free Spaces hardware) have no bfloat16
        # support; fall back to fp16 for the 4-bit compute dtype there.
        compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return dict(
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            ),
            torch_dtype=compute_dtype,
            device_map="auto",
        )
    return dict(
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
| def _is_seq2seq(cfg: AutoConfig) -> bool: | |
| arch = (cfg.architectures or [""])[0].lower() | |
| return "seq2seq" in arch or "conditionalgeneration" in arch or "mbart" in arch or "marian" in arch or "t5" in arch | |
| def _is_causal(cfg: AutoConfig) -> bool: | |
| arch = (cfg.architectures or [""])[0].lower() | |
| return "causallm" in arch or "llama" in arch or "gpt" in arch or "mistral" in arch | |
| def _is_classifier(cfg: AutoConfig) -> bool: | |
| arch = (cfg.architectures or [""])[0].lower() | |
| return "sequenceclassification" in arch | |
def load_any(repo_id: str):
    """Load tokenizer + model for *repo_id*, auto-detecting the model family.

    Family detection is driven by the repo's config `architectures`; anything
    not recognized as seq2seq or classifier defaults to a causal LM.

    Returns:
        tuple: (kind, tokenizer, model) where kind is one of
        "seq2seq" | "classifier" | "causal".
    """
    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE)
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=TRUST_REMOTE_CODE)
    # Generation and batched/truncated classification both need a pad token;
    # many checkpoints ship without one, so reuse EOS. Done once here instead
    # of per-branch (the original skipped the classifier branch entirely).
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    if _is_seq2seq(cfg):
        model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs())
        return ("seq2seq", tok, model)
    if _is_classifier(cfg):
        model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs())
        return ("classifier", tok, model)
    # Default to causal LM.
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs())
    return ("causal", tok, model)
# ----- Translator (NLLB-200 600M) -----
# Loaded eagerly at import time; tr_type is expected to be "seq2seq"
# (translate() refuses to run otherwise).
tr_type, tr_tok, tr_model = load_any(HF_TRANSLATOR_MODEL)
def translate(text: str, src_code: str, tgt_code: str, temperature: float, top_p: float):
    """Translate *text* from *src_code* to *tgt_code* with the NLLB model.

    Language codes are FLORES-200 style, e.g. "eng_Latn" / "som_Latn".
    Uses greedy decoding when temperature == 0, sampling otherwise.
    """
    if tr_type != "seq2seq":
        return "Translator must be a seq2seq model."
    # Resolve the target language's forced-BOS token id. Recent transformers
    # releases removed `tokenizer.lang_code_to_id` for NLLB tokenizers, in
    # which case the original hasattr-guard silently dropped the forced BOS
    # and output came back in the wrong language — fall back to
    # convert_tokens_to_ids (FLORES codes are ordinary special tokens).
    forced = {}
    lang_map = getattr(tr_tok, "lang_code_to_id", None)
    if lang_map and tgt_code in lang_map:
        forced["forced_bos_token_id"] = lang_map[tgt_code]
    else:
        tgt_id = tr_tok.convert_tokens_to_ids(tgt_code)
        if tgt_id is not None and tgt_id != tr_tok.unk_token_id:
            forced["forced_bos_token_id"] = tgt_id
    tr_tok.src_lang = src_code  # must be set before tokenizing
    inputs = tr_tok(text, return_tensors="pt", padding=True, truncation=True).to(tr_model.device)
    # temperature == 0 with do_sample=True raises in generate(); go greedy.
    do_sample = temperature > 0
    with torch.inference_mode():
        out = tr_model.generate(
            **inputs,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=1,
            length_penalty=1.0,
            **forced,
        )
    return tr_tok.decode(out[0], skip_special_tokens=True)
# ----- AgriParam (auto-detect clf vs causal) -----
ap_type, ap_tok, ap_model = load_any(HF_AGRIPARAM_MODEL)
# Classification goes through a pipeline; generation uses the model directly.
ap_pipe: Optional[TextClassificationPipeline] = None
if ap_type == "classifier":
    # NOTE(review): passing `device=` for a model loaded with device_map="auto"
    # (accelerate-dispatched) can conflict — confirm on GPU Spaces hardware.
    ap_pipe = pipeline("text-classification", model=ap_model, tokenizer=ap_tok,
                       device=0 if torch.cuda.is_available() else -1, truncation=True)
def agriparam_infer(text: str, temperature: float, top_p: float):
    """Run AgriParam on *text*: label scores if it is a classifier, free-form
    generation otherwise. Returns a display string for the UI."""
    if ap_type == "classifier":
        # top_k=None returns every label's score (the old return_all_scores=True
        # kwarg is deprecated/removed in current transformers). For a single
        # string input the pipeline returns a flat list of {label, score} dicts.
        results = ap_pipe(text, top_k=None)
        results = sorted(results, key=lambda d: d["score"], reverse=True)
        return "\n".join(f"{r['label']}: {r['score']:.4f}" for r in results)
    # Treat as a generator.
    inputs = ap_tok(text, return_tensors="pt").to(ap_model.device)
    # temperature == 0 with do_sample=True raises in generate(); go greedy.
    do_sample = temperature > 0
    with torch.inference_mode():
        out = ap_model.generate(
            **inputs,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            max_new_tokens=MAX_NEW_TOKENS,
            # Prefer the tokenizer's real pad token; fall back to EOS.
            pad_token_id=ap_tok.pad_token_id if ap_tok.pad_token_id is not None else ap_tok.eos_token_id,
        )
    return ap_tok.decode(out[0], skip_special_tokens=True)
# ----- LlamaX chat (8B) -----
# Loaded eagerly at import time; expected to be a causal LM.
lx_type, lx_tok, lx_model = load_any(HF_LLAMAX_MODEL)
def _apply_chat_template(user_msg: str, system_prompt: str = "You are a helpful Somali agriculture assistant."):
    """Render the chat prompt, preferring the tokenizer's own chat template."""
    if not hasattr(lx_tok, "apply_chat_template"):
        # Fallback for tokenizers without a template: Llama-2-style [INST]
        # wrapping with an inline system block.
        return f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n[INST] {user_msg} [/INST]"
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_msg},
    ]
    return lx_tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
def llamax_chat(user_msg: str, system_prompt: str, temperature: float, top_p: float):
    """Chat with the LlamaX model; returns only the newly generated reply."""
    prompt = _apply_chat_template(user_msg, system_prompt)
    inputs = lx_tok(prompt, return_tensors="pt").to(lx_model.device)
    # temperature == 0 with do_sample=True raises in generate(); go greedy.
    do_sample = temperature > 0
    with torch.inference_mode():
        out = lx_model.generate(
            **inputs,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=lx_tok.eos_token_id,
        )
    # Slice off the prompt tokens and decode only the completion. The original
    # `text.replace(prompt, "")` was unreliable: decode(skip_special_tokens=True)
    # strips the chat template's special tokens, so the decoded text never
    # contains the literal prompt and the whole prompt leaked into the reply.
    new_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return lx_tok.decode(new_tokens, skip_special_tokens=True).strip()
# ----- Gradio UI -----
# NOTE(review): the nesting below was reconstructed from a whitespace-mangled
# paste — confirm which widgets sit inside gr.Row() against the live Space.
with gr.Blocks(title="Somali Agri • LlamaX + AgriParam + NLLB") as demo:
    gr.Markdown("### 🌾 Somali Agri Suite\n- **LlamaX 8B** chat\n- **AgriParam** (classification or generator)\n- **NLLB-200 600M** translator")
    with gr.Tabs():
        # Tab 1: NLLB translation — free-text FLORES-200 language codes.
        with gr.Tab("Translator (NLLB-200)"):
            src = gr.Textbox(label="Source text")
            with gr.Row():
                src_code = gr.Textbox(value="eng_Latn", label="Source language code")
                tgt_code = gr.Textbox(value="som_Latn", label="Target language code")
            t_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            t_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            t_btn = gr.Button("Translate")
            t_out = gr.Textbox(label="Translation", lines=6)
            t_btn.click(translate, [src, src_code, tgt_code, t_temp, t_topp], t_out)
        # Tab 2: AgriParam — classification scores or generated text,
        # depending on what load_any() detected.
        with gr.Tab("AgriParam"):
            ap_in = gr.Textbox(label="Text / instruction")
            ap_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            ap_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            ap_btn = gr.Button("Run")
            ap_out = gr.Textbox(label="Output", lines=10)
            ap_btn.click(agriparam_infer, [ap_in, ap_temp, ap_topp], ap_out)
        # Tab 3: LlamaX chat with an editable system prompt.
        with gr.Tab("LlamaX Chat"):
            # NOTE(review): `sys` shadows the stdlib module name (not imported
            # here, so harmless, but worth renaming).
            sys = gr.Textbox(value="You are a helpful Somali agriculture assistant.", label="System prompt")
            user = gr.Textbox(label="User message")
            lx_temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
            lx_topp = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            lx_btn = gr.Button("Generate")
            lx_out = gr.Textbox(label="Assistant", lines=12)
            lx_btn.click(llamax_chat, [user, sys, lx_temp, lx_topp], lx_out)
    # Footer: echo the loaded model repos and quantization setting.
    gr.Markdown(
        f"**Loaded**:\n- Translator: `{HF_TRANSLATOR_MODEL}`\n- AgriParam: `{HF_AGRIPARAM_MODEL}`\n- LlamaX: `{HF_LLAMAX_MODEL}`\n- 4-bit quant: `{LOAD_4BIT}`"
    )
# Spaces launches the app by importing this module; the guard keeps local runs working.
if __name__ == "__main__":
    demo.launch()