# final_test / app.py
# Space: nurfarah57 — commit "Create app.py" (e0b4a49, verified)
import os, torch, gradio as gr
from typing import Optional
from transformers import (
AutoTokenizer, AutoConfig,
AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModelForSequenceClassification,
TextClassificationPipeline, pipeline
)
# --- YOUR MODELS ---
# Hugging Face Hub repo ids for the three models served by this Space.
HF_TRANSLATOR_MODEL = "facebook/nllb-200-distilled-600M" # seq2seq
HF_AGRIPARAM_MODEL = "bharatgenai/AgriParam" # classifier or causal; we auto-detect
HF_LLAMAX_MODEL = "nurfarah57/Somali-Agri-LLaMAX3-8B-Merged" # LLaMA-family chat
# --- SETTINGS (override via Space Variables if you like) ---
# LOAD_4BIT: "1" enables bitsandbytes NF4 quantization (GPU only, see _bnb_kwargs).
LOAD_4BIT = os.getenv("LOAD_4BIT", "1") == "1" # keep 4-bit on for small VRAM
# Generation budget shared by all three generate() call sites below.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
# Passed to every from_pretrained; required for repos shipping custom code.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "1") == "1"
def _bnb_kwargs():
    """Build the keyword arguments shared by every ``from_pretrained`` call.

    On CUDA with LOAD_4BIT enabled: a 4-bit NF4 bitsandbytes setup
    (bf16 compute dtype, double quantization) with ``device_map="auto"``.
    Otherwise: plain dtype/placement — bf16 + auto map on GPU, fp32 and no
    device map on CPU.
    """
    has_cuda = torch.cuda.is_available()
    if not (LOAD_4BIT and has_cuda):
        # No-quantization path: choose dtype/placement from CUDA availability.
        return {
            "torch_dtype": torch.bfloat16 if has_cuda else torch.float32,
            "device_map": "auto" if has_cuda else None,
        }
    # Imported lazily: bitsandbytes is only needed on the quantized GPU path.
    from transformers import BitsAndBytesConfig
    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    return {
        "quantization_config": quant,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
def _is_seq2seq(cfg: AutoConfig) -> bool:
arch = (cfg.architectures or [""])[0].lower()
return "seq2seq" in arch or "conditionalgeneration" in arch or "mbart" in arch or "marian" in arch or "t5" in arch
def _is_causal(cfg: AutoConfig) -> bool:
arch = (cfg.architectures or [""])[0].lower()
return "causallm" in arch or "llama" in arch or "gpt" in arch or "mistral" in arch
def _is_classifier(cfg: AutoConfig) -> bool:
arch = (cfg.architectures or [""])[0].lower()
return "sequenceclassification" in arch
def load_any(repo_id: str):
    """Load tokenizer + model for *repo_id*, auto-detecting the head type.

    Returns a ``(kind, tokenizer, model)`` triple where *kind* is one of
    ``"seq2seq"``, ``"classifier"``, or ``"causal"`` (the default when the
    architecture matches neither of the first two). Generative paths get a
    pad token (eos) assigned when the checkpoint ships without one.
    """
    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE)
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=TRUST_REMOTE_CODE)

    def _ensure_pad(t):
        # Generation requires a pad token; many causal checkpoints lack one.
        if t.pad_token is None:
            t.pad_token = t.eos_token

    kwargs = _bnb_kwargs()
    if _is_seq2seq(cfg):
        mdl = AutoModelForSeq2SeqLM.from_pretrained(
            repo_id, trust_remote_code=TRUST_REMOTE_CODE, **kwargs)
        _ensure_pad(tok)
        return ("seq2seq", tok, mdl)
    if _is_classifier(cfg):
        mdl = AutoModelForSequenceClassification.from_pretrained(
            repo_id, trust_remote_code=TRUST_REMOTE_CODE, **kwargs)
        return ("classifier", tok, mdl)
    # Anything else is treated as a causal LM.
    mdl = AutoModelForCausalLM.from_pretrained(
        repo_id, trust_remote_code=TRUST_REMOTE_CODE, **kwargs)
    _ensure_pad(tok)
    return ("causal", tok, mdl)
# ----- Translator (NLLB-200 600M) -----
# Loaded eagerly at import time; load_any returns (kind, tokenizer, model).
tr_type, tr_tok, tr_model = load_any(HF_TRANSLATOR_MODEL)
def translate(text: str, src_code: str, tgt_code: str, temperature: float, top_p: float):
    """Translate *text* with the NLLB seq2seq model.

    src_code / tgt_code are FLORES-200 codes (e.g. "eng_Latn", "som_Latn").
    temperature <= 0 falls back to greedy decoding; otherwise nucleus
    sampling with the given temperature/top_p is used.
    Returns the decoded translation, or an error string if the loaded
    translator is not a seq2seq model.
    """
    if tr_type != "seq2seq":
        return "Translator must be a seq2seq model."
    # Resolve the forced BOS token for the target language. Newer transformers
    # releases removed `lang_code_to_id` from the NLLB tokenizer, so the old
    # hasattr guard silently skipped forcing and NLLB then decoded in the
    # wrong language; fall back to convert_tokens_to_ids (language codes are
    # ordinary vocab tokens in NLLB).
    forced = {}
    if hasattr(tr_tok, "lang_code_to_id") and tgt_code in tr_tok.lang_code_to_id:
        forced["forced_bos_token_id"] = tr_tok.lang_code_to_id[tgt_code]
    else:
        tid = tr_tok.convert_tokens_to_ids(tgt_code)
        if tid is not None and tid != tr_tok.unk_token_id:
            forced["forced_bos_token_id"] = tid
    tr_tok.src_lang = src_code  # must be set before tokenizing the source
    inputs = tr_tok(text, return_tensors="pt", padding=True, truncation=True).to(tr_model.device)
    gen_kwargs = dict(max_new_tokens=MAX_NEW_TOKENS, num_beams=1, length_penalty=1.0, **forced)
    # The UI slider allows temperature == 0, which makes transformers raise
    # with do_sample=True; use greedy decoding in that case.
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False
    with torch.inference_mode():
        out = tr_model.generate(**inputs, **gen_kwargs)
    return tr_tok.decode(out[0], skip_special_tokens=True)
# ----- AgriParam (auto-detect clf vs causal) -----
# Loaded eagerly; kind decides whether agriparam_infer classifies or generates.
ap_type, ap_tok, ap_model = load_any(HF_AGRIPARAM_MODEL)
# A text-classification pipeline is only built for classifier checkpoints;
# it stays None for the causal-LM path.
ap_pipe: Optional[TextClassificationPipeline] = None
if ap_type == "classifier":
    ap_pipe = pipeline("text-classification", model=ap_model, tokenizer=ap_tok,
                       device=0 if torch.cuda.is_available() else -1, truncation=True)
def agriparam_infer(text: str, temperature: float, top_p: float):
    """Run AgriParam on *text*.

    Classifier checkpoints: return every label with its score, best first,
    one "label: score" line per label. Otherwise: generate a continuation
    with the causal LM (greedy when temperature <= 0, sampling otherwise).
    """
    if ap_type == "classifier":
        # NOTE(review): return_all_scores is deprecated in newer transformers
        # (replaced by top_k=None, which also flattens the result shape);
        # kept as-is so the [0] indexing below stays correct.
        res = ap_pipe(text, return_all_scores=True)[0]
        res = sorted(res, key=lambda d: d["score"], reverse=True)
        return "\n".join(f"{r['label']}: {r['score']:.4f}" for r in res)
    # Generator path.
    inputs = ap_tok(text, return_tensors="pt").to(ap_model.device)
    gen_kwargs = dict(max_new_tokens=MAX_NEW_TOKENS, pad_token_id=ap_tok.eos_token_id)
    # The UI slider allows temperature == 0, which makes transformers raise
    # when do_sample=True; fall back to greedy decoding in that case.
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False
    with torch.inference_mode():
        out = ap_model.generate(**inputs, **gen_kwargs)
    return ap_tok.decode(out[0], skip_special_tokens=True)
# ----- LlamaX chat (8B) -----
# Loaded eagerly at import time; expected to resolve as a causal LM.
lx_type, lx_tok, lx_model = load_any(HF_LLAMAX_MODEL)
def _apply_chat_template(user_msg: str, system_prompt: str = "You are a helpful Somali agriculture assistant."):
    """Format a single-turn chat prompt for the LlamaX tokenizer.

    Uses the tokenizer's own chat template when available; otherwise falls
    back to the classic Llama-2 ``<<SYS>>`` / ``[INST]`` layout.
    """
    if not hasattr(lx_tok, "apply_chat_template"):
        # Legacy fallback for tokenizers without chat-template support.
        return f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n[INST] {user_msg} [/INST]"
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_msg},
    ]
    return lx_tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
def llamax_chat(user_msg: str, system_prompt: str, temperature: float, top_p: float):
    """Generate a single assistant reply from the LlamaX chat model.

    Builds a system+user chat prompt, generates up to MAX_NEW_TOKENS, and
    returns only the newly generated text. temperature <= 0 falls back to
    greedy decoding (the UI slider allows 0, which would otherwise make
    transformers raise with do_sample=True).
    """
    prompt = _apply_chat_template(user_msg, system_prompt)
    inputs = lx_tok(prompt, return_tensors="pt").to(lx_model.device)
    gen_kwargs = dict(max_new_tokens=MAX_NEW_TOKENS, pad_token_id=lx_tok.eos_token_id)
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False
    with torch.inference_mode():
        out = lx_model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated token ids. The old
    # `text.replace(prompt, "")` never matched: decoding with
    # skip_special_tokens=True strips the chat-template special tokens, so
    # the decoded text did not contain `prompt` verbatim and the whole
    # prompt was echoed back to the user.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return lx_tok.decode(new_tokens, skip_special_tokens=True).strip()
# ----- Gradio UI -----
# Three tabs, one per model. Each button wires its inputs straight into the
# corresponding inference function defined above.
with gr.Blocks(title="Somali Agri • LlamaX + AgriParam + NLLB") as demo:
    gr.Markdown("### 🌾 Somali Agri Suite\n- **LlamaX 8B** chat\n- **AgriParam** (classification or generator)\n- **NLLB-200 600M** translator")
    with gr.Tabs():
        # Tab 1: NLLB translator — free text plus FLORES-200 language codes.
        with gr.Tab("Translator (NLLB-200)"):
            src = gr.Textbox(label="Source text")
            with gr.Row():
                src_code = gr.Textbox(value="eng_Latn", label="Source language code")
                tgt_code = gr.Textbox(value="som_Latn", label="Target language code")
            t_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            t_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            t_btn = gr.Button("Translate")
            t_out = gr.Textbox(label="Translation", lines=6)
            t_btn.click(translate, [src, src_code, tgt_code, t_temp, t_topp], t_out)
        # Tab 2: AgriParam — classification scores or generated text,
        # depending on how the checkpoint was detected at load time.
        with gr.Tab("AgriParam"):
            ap_in = gr.Textbox(label="Text / instruction")
            ap_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            ap_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            ap_btn = gr.Button("Run")
            ap_out = gr.Textbox(label="Output", lines=10)
            ap_btn.click(agriparam_infer, [ap_in, ap_temp, ap_topp], ap_out)
        # Tab 3: LlamaX chat with an editable system prompt.
        with gr.Tab("LlamaX Chat"):
            # NOTE(review): `sys` shadows the stdlib module name; harmless
            # here since the `sys` module is never imported in this file.
            sys = gr.Textbox(value="You are a helpful Somali agriculture assistant.", label="System prompt")
            user = gr.Textbox(label="User message")
            lx_temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
            lx_topp = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            lx_btn = gr.Button("Generate")
            lx_out = gr.Textbox(label="Assistant", lines=12)
            lx_btn.click(llamax_chat, [user, sys, lx_temp, lx_topp], lx_out)
    # Footer: echo which repos/settings were loaded, for quick debugging.
    gr.Markdown(
        f"**Loaded**:\n- Translator: `{HF_TRANSLATOR_MODEL}`\n- AgriParam: `{HF_AGRIPARAM_MODEL}`\n- LlamaX: `{HF_LLAMAX_MODEL}`\n- 4-bit quant: `{LOAD_4BIT}`"
    )
if __name__ == "__main__":
    demo.launch()