# Source: CHATBOT / src/streamlit_app.py (Hugging Face Space by EYEDOL)
# Last commit: 0fb3268 "Update src/streamlit_app.py" (verified)
# -*- coding: utf-8 -*-
"""
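Salama Assistant (refactored): a text-only Swahili chatbot. Speech-to-text
and text-to-speech have been removed.

Despite the streamlit_app.py file name, this script builds a Gradio app. Use it
as the entry point of a Hugging Face Space (e.g. in place of an existing app.py),
or run it locally.

Requirements:
- transformers
- peft
- gradio
- huggingface_hub
- torch
- accelerate (required by transformers when device_map="auto" is used, i.e. on GPU)

Notes:
- For private models, set HF_TOKEN in the environment or as a Space secret.
  A variable named "hugface" is also accepted as a fallback.
- The app loads the base LLM, applies the PEFT adapter, and streams text
  responses into the Gradio chat UI.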
"""
import os
import threading
import gradio as gr
import torch
from huggingface_hub import login
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig
# -------------------- Configuration --------------------
ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-3b_ON_ALPACA5"  # adapter-only repo
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct"  # full base model referenced by adapter

# Access token: prefer HF_TOKEN, otherwise fall back to a variable named "hugface".
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")

# Log in to the Hub once, at import time, before any repo is downloaded.
if not HF_TOKEN:
    print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")
else:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        # Best-effort: keep going; only private repos are expected to need the login.
        print("Warning: huggingface_hub.login() failed:", e)
    else:
        print("Successfully logged into Hugging Face Hub!")
class WeeboAssistant:
    """Text-only Swahili chat assistant: base LLM + PEFT adapter, streamed replies.

    Construction eagerly loads the tokenizer, the base model ``BASE_MODEL_ID``
    and the adapter ``ADAPTER_REPO_ID``. It is therefore slow, and it raises
    ``RuntimeError`` if the base model or the adapter cannot be loaded.
    """

    def __init__(self):
        # System instruction, in Swahili. Roughly: "You are an intelligent
        # assistant; answer the question asked BRIEFLY and accurately. Answer
        # in Swahili only. No long answers."
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa KWA UFUPI na kwa usahihi. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu.\n"
        )
        self._init_models()

    def _init_models(self):
        """Load the tokenizer, base model and PEFT adapter, and try to build a pipeline.

        Sets ``self.device``, ``self.torch_dtype``, ``self.llm_tokenizer``,
        ``self.llm_model`` and ``self.llm_pipeline``. ``self.llm_pipeline`` is
        ``None`` if the pipeline could not be created.

        Raises:
            RuntimeError: if the base model or the adapter fails to load.
        """
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")

        # 1) Tokenizer: prefer the base model's; fall back to the adapter repo's.
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)

        # 2) Base model. On GPU, let accelerate place the weights (device_map="auto").
        device_map = "auto" if self.device == "cuda" else None
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                device_map=device_map,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and the HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e

        # 3) Apply the PEFT adapter (the adapter repo holds adapter weights only).
        try:
            peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e

        # The adapter config records the base model it was built against.
        # Warn (but continue) if that differs from the model actually loaded.
        adapter_base = getattr(peft_config, "base_model_name_or_path", None)
        if adapter_base and adapter_base != BASE_MODEL_ID:
            print(
                f"Warning: adapter config references base model '{adapter_base}', "
                f"but '{BASE_MODEL_ID}' was loaded."
            )

        # 4) Optional non-streaming pipeline (handy for quick tests; the
        # streaming path in get_llm_response calls generate() directly).
        # A model dispatched with a device map (accelerate) already sits on its
        # devices, and pipeline() refuses an explicit `device` for such models.
        # So pass `device` only when no device map is present.
        pipeline_kwargs = {}
        if getattr(self.llm_model, "hf_device_map", None) is None:
            pipeline_kwargs["device"] = 0 if self.device == "cuda" else -1
        try:
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                **pipeline_kwargs,
            )
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None
        print("LLM base + adapter loaded successfully.")

    def _build_prompt(self, chat_history):
        """Render the system prompt and chat turns as one plain-text prompt.

        Each ``(user_msg, assistant_msg)`` pair becomes ``"User: ..."`` /
        ``"Assistant: ..."`` lines. Empty or ``None`` messages are skipped.
        The prompt ends with an open ``"Assistant: "`` turn for the model to
        complete.
        """
        prompt_lines = [self.SYSTEM_PROMPT]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        return "\n".join(prompt_lines)

    def _generate_into(self, streamer, generation_kwargs):
        """Thread target: run ``generate`` and always terminate ``streamer``.

        If ``generate`` raises, nothing else would send the stream's stop
        signal, and whoever iterates ``streamer`` would block forever.
        ``streamer.end()`` flushes any pending text and ends the iteration.
        """
        try:
            self.llm_model.generate(**generation_kwargs)
        except Exception as e:
            print("Error: text generation failed:", e)
            streamer.end()

    def get_llm_response(self, chat_history):
        """Start generating a reply to ``chat_history`` and return a text stream.

        Args:
            chat_history: sequence of ``(user_msg, assistant_msg)`` pairs. The
                last pair is normally the new user message with an empty reply.

        Returns:
            A ``TextIteratorStreamer``. Iterating it yields decoded text chunks
            as they are generated (prompt and special tokens are skipped).
            Generation runs in a background daemon thread.
        """
        prompt = self._build_prompt(chat_history)
        inputs = self.llm_tokenizer(prompt, return_tensors="pt")

        # Put the inputs on the device of the model's first parameter.
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
        )
        # Generate in the background so the caller can consume the stream live.
        gen_thread = threading.Thread(
            target=self._generate_into,
            args=(streamer, generation_kwargs),
            daemon=True,
        )
        gen_thread.start()
        return streamer
# -------------------- Create assistant instance --------------------
# Built once at import time. This loads the tokenizer, base model and PEFT
# adapter, so importing/running this module blocks until they are ready and
# raises RuntimeError if the model or adapter fails to load. The Gradio
# handler below (t2t_pipeline) uses this single shared instance.
assistant = WeeboAssistant()
# -------------------- Gradio pipelines --------------------
def t2t_pipeline(text_input, chat_history):
    """Stream the assistant's reply to ``text_input`` into the Gradio chat.

    Args:
        text_input: the user's new message from the textbox.
        chat_history: list of ``(user, assistant)`` pairs held by the Chatbot.
            ``None`` is treated as an empty history. The list passed in is
            not mutated.

    Yields:
        The updated history. It is first yielded with the user's turn and an
        empty reply, then once per streamed text chunk. Blank input yields
        the history unchanged and never calls the model.
    """
    chat_history = list(chat_history or [])

    # Don't send empty/whitespace-only messages to the model.
    if not text_input or not text_input.strip():
        yield chat_history
        return

    # Show the user's message immediately, with an empty assistant reply.
    chat_history.append((text_input, ""))
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
def clear_textbox():
    """Return the value that resets the input Textbox to empty.

    Returning the new value directly works on every Gradio version, whereas
    the class-level ``gr.Textbox.update`` helper was removed in Gradio 4.
    """
    return ""
# -------------------- Gradio UI --------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili - Text Chat") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Kiswahili (Text Chat)")
    gr.Markdown("Ongea (aina ya maandishi) na msaidizi kwa Kiswahili. Tumia kisanduku kifuatacho kuandika.")
    t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
    with gr.Row():
        t2t_text_in = gr.Textbox(show_label=False, placeholder="Andika hapa...", scale=4, container=False)
        t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)

    # Clicking the button and pressing Enter in the textbox behave identically:
    # stream the reply into the chatbot, then clear the input box.
    # Registration order (click first, then submit) matches the original.
    for _trigger in (t2t_submit_btn.click, t2t_text_in.submit):
        _trigger(
            fn=t2t_pipeline,
            inputs=[t2t_text_in, t2t_chatbot],
            outputs=[t2t_chatbot],
            queue=True,
        ).then(
            fn=clear_textbox,
            inputs=None,
            outputs=t2t_text_in,
        )

# Enable the request queue (needed for streaming generator handlers) and serve.
demo.queue().launch(debug=True)