# NOTE: removed non-Python scrape artifacts that preceded the code
# (a "File size" banner, a commit hash, and a copied line-number gutter).
import os, torch, gradio as gr
from typing import Optional
from transformers import (
    AutoTokenizer, AutoConfig,
    AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModelForSequenceClassification,
    TextClassificationPipeline, pipeline
)

# --- YOUR MODELS ---
# Hugging Face Hub repo IDs loaded at import time (see load_any below).
HF_TRANSLATOR_MODEL = "facebook/nllb-200-distilled-600M"          # seq2seq
HF_AGRIPARAM_MODEL  = "bharatgenai/AgriParam"                     # classifier or causal; we auto-detect
HF_LLAMAX_MODEL     = "nurfarah57/Somali-Agri-LLaMAX3-8B-Merged"  # LLaMA-family chat
# --- SETTINGS (override via Space Variables if you like) ---
# All three read environment variables so a Space can reconfigure without code changes.
LOAD_4BIT = os.getenv("LOAD_4BIT", "1") == "1"             # keep 4-bit on for small VRAM
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))   # generation cap shared by all three models
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "1") == "1"  # needed for repos shipping custom code

def _bnb_kwargs():
    """Build the ``from_pretrained`` kwargs appropriate for the current hardware.

    With a CUDA device and LOAD_4BIT enabled, returns a bitsandbytes NF4
    4-bit quantization setup; otherwise falls back to plain bf16 (GPU) or
    fp32 (CPU) with no quantization.
    """
    on_gpu = torch.cuda.is_available()
    if LOAD_4BIT and on_gpu:
        # Imported lazily: bitsandbytes is only needed on the quantized path.
        from transformers import BitsAndBytesConfig
        quant = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        return {
            "quantization_config": quant,
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }
    return {
        "torch_dtype": torch.bfloat16 if on_gpu else torch.float32,
        "device_map": "auto" if on_gpu else None,
    }

def _is_seq2seq(cfg: AutoConfig) -> bool:
    """True when the config's first declared architecture looks encoder-decoder."""
    arch = (cfg.architectures or [""])[0].lower()
    markers = ("seq2seq", "conditionalgeneration", "mbart", "marian", "t5")
    return any(marker in arch for marker in markers)

def _is_causal(cfg: AutoConfig) -> bool:
    """True when the config's first declared architecture looks decoder-only."""
    arch = (cfg.architectures or [""])[0].lower()
    markers = ("causallm", "llama", "gpt", "mistral")
    return any(marker in arch for marker in markers)

def _is_classifier(cfg: AutoConfig) -> bool:
    """True when the config's first declared architecture is a sequence classifier."""
    first_arch = (cfg.architectures or [""])[0]
    return "sequenceclassification" in first_arch.lower()

def load_any(repo_id: str):
    """Load tokenizer and model for *repo_id*, auto-detecting the head type.

    Returns a ``(kind, tokenizer, model)`` triple where *kind* is one of
    "seq2seq", "classifier", or "causal" (the default when nothing matches).
    Generative kinds get ``pad_token`` backfilled from ``eos_token`` when the
    tokenizer defines none.
    """
    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE)
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True, trust_remote_code=TRUST_REMOTE_CODE)

    if _is_seq2seq(cfg):
        kind, model_cls = "seq2seq", AutoModelForSeq2SeqLM
    elif _is_classifier(cfg):
        kind, model_cls = "classifier", AutoModelForSequenceClassification
    else:
        # Anything unrecognized is treated as a causal LM.
        kind, model_cls = "causal", AutoModelForCausalLM

    model = model_cls.from_pretrained(repo_id, trust_remote_code=TRUST_REMOTE_CODE, **_bnb_kwargs())

    # generate() needs a pad token; classification does not.
    if kind != "classifier" and tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return (kind, tok, model)

# ----- Translator (NLLB-200 600M) -----
# Loaded once at import time; tr_type should come back "seq2seq" for NLLB.
tr_type, tr_tok, tr_model = load_any(HF_TRANSLATOR_MODEL)

def translate(text: str, src_code: str, tgt_code: str, temperature: float, top_p: float):
    """Translate *text* with the NLLB seq2seq model.

    *src_code* / *tgt_code* are NLLB-style language codes (e.g. ``eng_Latn``,
    ``som_Latn``). A ``temperature`` of 0 (reachable from the UI slider)
    falls back to greedy decoding, since sampling with temperature 0 raises
    inside ``generate()``.
    """
    if tr_type != "seq2seq":
        return "Translator must be a seq2seq model."
    if not text or not text.strip():
        return ""

    # Always tell the tokenizer the source language. The original only did
    # this when the target code resolved, silently tokenizing with the wrong
    # source language on a miss.
    tr_tok.src_lang = src_code

    forced = {}
    if hasattr(tr_tok, "lang_code_to_id") and tgt_code in tr_tok.lang_code_to_id:
        forced["forced_bos_token_id"] = tr_tok.lang_code_to_id[tgt_code]
    else:
        # Newer transformers tokenizers dropped lang_code_to_id; the language
        # codes are plain vocabulary tokens there.
        tgt_id = tr_tok.convert_tokens_to_ids(tgt_code)
        if tgt_id is not None and tgt_id != tr_tok.unk_token_id:
            forced["forced_bos_token_id"] = tgt_id

    inputs = tr_tok(text, return_tensors="pt", padding=True, truncation=True).to(tr_model.device)
    if temperature > 0:
        sampling = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling = dict(do_sample=False)
    with torch.inference_mode():
        out = tr_model.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=1,
            length_penalty=1.0, **sampling, **forced
        )
    return tr_tok.decode(out[0], skip_special_tokens=True)

# ----- AgriParam (auto-detect clf vs causal) -----
# Loaded once at import time; ap_pipe is only built when the repo turns out
# to be a sequence-classification head, otherwise agriparam_infer generates.
ap_type, ap_tok, ap_model = load_any(HF_AGRIPARAM_MODEL)
ap_pipe: Optional[TextClassificationPipeline] = None
if ap_type == "classifier":
    ap_pipe = pipeline("text-classification", model=ap_model, tokenizer=ap_tok,
                       device=0 if torch.cuda.is_available() else -1, truncation=True)

def agriparam_infer(text: str, temperature: float, top_p: float):
    """Run AgriParam on *text*.

    Classifier checkpoints return all label scores sorted descending, one
    per line; anything else is treated as a causal generator. A
    ``temperature`` of 0 (reachable from the UI slider) falls back to greedy
    decoding, since sampling with temperature 0 raises inside ``generate()``.
    """
    if ap_type == "classifier":
        # NOTE(review): return_all_scores is deprecated in newer transformers
        # (top_k=None is the replacement, with a flatter output shape) —
        # confirm the pinned transformers version before migrating.
        scores = ap_pipe(text, return_all_scores=True)[0]
        scores = sorted(scores, key=lambda d: d["score"], reverse=True)
        return "\n".join(f"{s['label']}: {s['score']:.4f}" for s in scores)

    # Treat as a causal generator.
    inputs = ap_tok(text, return_tensors="pt").to(ap_model.device)
    if temperature > 0:
        sampling = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling = dict(do_sample=False)
    with torch.inference_mode():
        out = ap_model.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=ap_tok.eos_token_id, **sampling
        )
    return ap_tok.decode(out[0], skip_special_tokens=True)

# ----- LlamaX chat (8B) -----
# Loaded once at import time; lx_type is expected to be "causal".
lx_type, lx_tok, lx_model = load_any(HF_LLAMAX_MODEL)

def _apply_chat_template(user_msg: str, system_prompt: str = "You are a helpful Somali agriculture assistant."):
    """Render a single-turn chat prompt, preferring the tokenizer's own template."""
    if not hasattr(lx_tok, "apply_chat_template"):
        # Hand-rolled Llama-2-style fallback when no template is bundled.
        return f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n[INST] {user_msg} [/INST]"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_msg},
    ]
    return lx_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def llamax_chat(user_msg: str, system_prompt: str, temperature: float, top_p: float):
    """Generate a single assistant reply from the LlamaX model.

    A ``temperature`` of 0 (reachable from the UI slider) falls back to
    greedy decoding, since sampling with temperature 0 raises inside
    ``generate()``.
    """
    prompt = _apply_chat_template(user_msg, system_prompt)
    inputs = lx_tok(prompt, return_tensors="pt").to(lx_model.device)
    if temperature > 0:
        sampling = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling = dict(do_sample=False)
    with torch.inference_mode():
        out = lx_model.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=lx_tok.eos_token_id, **sampling
        )
    # Decode only the newly generated tokens. Re-decoding the full sequence
    # and str.replace(prompt, "") breaks whenever skip_special_tokens makes
    # the decoded prefix differ from the raw prompt string, leaking the
    # prompt into the reply.
    prompt_len = inputs["input_ids"].shape[1]
    return lx_tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

# ----- Gradio UI -----
# Three tabs, one per model. Each tab wires its widgets straight to the
# corresponding inference function defined above.
with gr.Blocks(title="Somali Agri • LlamaX + AgriParam + NLLB") as demo:
    gr.Markdown("### 🌾 Somali Agri Suite\n- **LlamaX 8B** chat\n- **AgriParam** (classification or generator)\n- **NLLB-200 600M** translator")
    with gr.Tabs():
        # Tab 1: free-text translation with editable NLLB language codes.
        with gr.Tab("Translator (NLLB-200)"):
            src = gr.Textbox(label="Source text")
            with gr.Row():
                src_code = gr.Textbox(value="eng_Latn", label="Source language code")
                tgt_code = gr.Textbox(value="som_Latn", label="Target language code")
            t_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            t_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            t_btn = gr.Button("Translate")
            t_out = gr.Textbox(label="Translation", lines=6)
            t_btn.click(translate, [src, src_code, tgt_code, t_temp, t_topp], t_out)

        # Tab 2: AgriParam — classification scores or generated text,
        # depending on what load_any detected at startup.
        with gr.Tab("AgriParam"):
            ap_in = gr.Textbox(label="Text / instruction")
            ap_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            ap_topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            ap_btn = gr.Button("Run")
            ap_out = gr.Textbox(label="Output", lines=10)
            ap_btn.click(agriparam_infer, [ap_in, ap_temp, ap_topp], ap_out)

        # Tab 3: single-turn chat (no history is carried between clicks).
        with gr.Tab("LlamaX Chat"):
            sys = gr.Textbox(value="You are a helpful Somali agriculture assistant.", label="System prompt")
            user = gr.Textbox(label="User message")
            lx_temp = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label="Temperature")
            lx_topp = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            lx_btn = gr.Button("Generate")
            lx_out = gr.Textbox(label="Assistant", lines=12)
            lx_btn.click(llamax_chat, [user, sys, lx_temp, lx_topp], lx_out)

    # Footer summarizing the resolved configuration for quick debugging.
    gr.Markdown(
        f"**Loaded**:\n- Translator: `{HF_TRANSLATOR_MODEL}`\n- AgriParam: `{HF_AGRIPARAM_MODEL}`\n- LlamaX: `{HF_LLAMAX_MODEL}`\n- 4-bit quant: `{LOAD_4BIT}`"
    )

if __name__ == "__main__":
    demo.launch()