# -*- coding: utf-8 -*-
"""
Refactored Salama Assistant: text-only chatbot (STT and TTS removed)
Drop this file into your Hugging Face Space (replace existing app.py) or run locally.
Requirements:
- transformers
- peft
- gradio
- huggingface_hub
- torch

Notes:
- Set HF_TOKEN in env for private models or use Spaces secret.
- This keeps the LLM + PEFT adapter loading and streaming text responses into the Gradio chat UI.
"""

import os
import threading
import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig

# -------------------- Configuration --------------------
ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-3b_ON_ALPACA5"  # adapter-only repo
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct"    # full base model referenced by adapter

HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged into Hugging Face Hub!")
    except Exception as e:
        print("Warning: huggingface_hub.login() failed:", e)
else:
    print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")


class WeeboAssistant:
    def __init__(self):
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa KWA UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu.\n"
        )
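        # English gloss of the prompt above: "You are an intelligent assistant; answer
        # the question asked BRIEFLY and accurately. Reply only in Swahili. No long answers."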
        self._init_models()

    def _init_models(self):
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # 1) Tokenizer (prefer base tokenizer)
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)

        # 2) Load base model
        device_map = "auto" if torch.cuda.is_available() else None
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                device_map=device_map,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and the HF_TOKEN has access if private. Error: "
                + str(e)
            )

        # 3) Load and apply PEFT adapter (adapter-only repo)
        try:
            peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
                + str(e)
            )

        # 4) Optional non-streaming pipeline (useful for small tests)
        try:
            pipeline_kwargs = {}
            if device_map is None:
                # Only pin a device when accelerate's device_map was not used; with
                # device_map="auto" the model is already dispatched, and passing an
                # explicit device to pipeline() can conflict with that placement.
                pipeline_kwargs["device"] = 0 if torch.cuda.is_available() else -1
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                **pipeline_kwargs,
            )
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None

        print("LLM base + adapter loaded successfully.")

    def get_llm_response(self, chat_history):
        # Build prompt from system + conversation history
        prompt_lines = [self.SYSTEM_PROMPT]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        prompt = "\n".join(prompt_lines)

        inputs = self.llm_tokenizer(prompt, return_tensors="pt")
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
        )
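        # Sampling settings: temperature 0.6 with top_p 0.9 keeps replies varied but
        # focused; max_new_tokens caps each answer at 512 generated tokens.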

        gen_thread = threading.Thread(target=self.llm_model.generate, kwargs=generation_kwargs, daemon=True)
        gen_thread.start()

        return streamer


# -------------------- Create assistant instance --------------------
assistant = WeeboAssistant()
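# Two optional sanity checks (rough sketches, not part of the app flow; "Habari yako?"
# is just a sample Swahili prompt):
#
#   # Non-streaming, via the optional pipeline built in _init_models:
#   if assistant.llm_pipeline is not None:
#       out = assistant.llm_pipeline("Habari yako?", max_new_tokens=32, do_sample=False)
#       print(out[0]["generated_text"])
#
#   # Streaming, consuming the TextIteratorStreamer directly:
#   for chunk in assistant.get_llm_response([("Habari yako?", "")]):
#       print(chunk, end="", flush=True)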


# -------------------- Gradio pipelines --------------------
def t2t_pipeline(text_input, chat_history):
    # Ignore empty submissions and tolerate a missing (None) history value
    chat_history = chat_history or []
    text_input = (text_input or "").strip()
    if not text_input:
        yield chat_history
        return

    # Append the user's message and stream the assistant reply
    chat_history.append((text_input, ""))
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history


def clear_textbox():
    # Returning a plain string resets the Textbox value; this works across Gradio
    # versions (gr.Textbox.update was removed in Gradio 4).
    return ""


# -------------------- Gradio UI --------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili - Text Chat") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Kiswahili (Text Chat)")
    gr.Markdown("Ongea (aina ya maandishi) na msaidizi kwa Kiswahili. Tumia kisanduku kifuatacho kuandika.")

    t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
    with gr.Row():
        t2t_text_in = gr.Textbox(show_label=False, placeholder="Andika hapa...", scale=4, container=False)
        t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )

    t2t_text_in.submit(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )


demo.queue().launch(debug=True)
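# Running `python app.py` locally serves the UI on http://127.0.0.1:7860 by default;
# on a Hugging Face Space this app.py is picked up automatically as the entry point.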