Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,18 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from langchain_text_splitters import CharacterTextSplitter
|
| 5 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 6 |
from langchain_community.vectorstores import FAISS
|
| 7 |
-
# from langchain.chains import RetrievalQA # Not used in this RAG implementation
|
| 8 |
|
| 9 |
# --- Configuration ---
|
| 10 |
-
|
| 11 |
-
|
| 12 |
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 13 |
-
TRANSCRIPT_FILE = "nurse_toto_episode_1_transcript.md"
|
| 14 |
|
| 15 |
-
# --- Transcript Data
|
| 16 |
-
# The full transcript is loaded here. In a real scenario, this would be loaded from a file.
|
| 17 |
-
# For simplicity and deployment, we'll embed the content directly.
|
| 18 |
NURSE_TOTO_TRANSCRIPT = """
|
| 19 |
# A Nurse Toto - Episode 1: Mzee wa Kutahirii (Kiswahili Transcript)
|
| 20 |
-
|
| 21 |
**Series:** A Nurse Toto
|
| 22 |
**Episode:** 1 - Mzee wa Kutahirii
|
| 23 |
**Creator:** Eddie Butita
|
|
@@ -81,7 +76,7 @@ NURSE_TOTO_TRANSCRIPT = """
|
|
| 81 |
**Maryanne:** Mzee, unajua unasumbua wewe? Hebu keti hapo. Utalipa 500 ya registration, utaona daktari na 1,000, alafu 15k, hiyo ni ya circumcision.
|
| 82 |
**Casypool:** Silipi kitu, niko na insurance.
|
| 83 |
**Maryanne:** Ni sawa, uko na insurance. But sasa sijui kama insurance inakava wazee wa umri yako kutahiri. Utangoja hapo usikie kama watakubali.
|
| 84 |
-
**
|
| 85 |
**Maryanne:** Mzee, lakini vitu zingine ni za kujisimamia. Hizi ni aibu gani za ati, "Oh, mzee wa 52 years, circumcision na NHIF." Surely. Surely.
|
| 86 |
|
| 87 |
---
|
|
@@ -153,148 +148,100 @@ NURSE_TOTO_TRANSCRIPT = """
|
|
| 153 |
**Sly:** Ndio maana ulikuwa unasema tungoje, sindio?
|
| 154 |
"""
|
| 155 |
|
| 156 |
-
# ---
|
| 157 |
-
# Global variables to hold the model and RAG chain
|
| 158 |
tokenizer = None
|
| 159 |
model = None
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def setup_rag_chain():
|
| 163 |
-
"""Initializes the LLM, tokenizer, and RAG chain."""
|
| 164 |
-
global tokenizer, model, rag_chain
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
-
# 1. Load the Swahili LLM (using a smaller model for deployment)
|
| 170 |
-
# Note: For a free Hugging Face Space, a small model is necessary.
|
| 171 |
-
# The UlizaLlama3 is 8B and will likely require a paid GPU.
|
| 172 |
-
# We will use a placeholder for the code, but advise the user.
|
| 173 |
try:
|
| 174 |
print(f"Loading tokenizer and model: {MODEL_NAME}...")
|
| 175 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
model = AutoModelForCausalLM.from_pretrained(
|
| 178 |
MODEL_NAME,
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
device_map="auto"
|
| 182 |
)
|
|
|
|
| 183 |
print("Model loaded successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
except Exception as e:
|
| 185 |
-
print(f"
|
| 186 |
-
#
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
# 4. Setup the RAG chain
|
| 209 |
-
# We'll use a simple pipeline for generation and integrate it with the retriever manually
|
| 210 |
-
# to avoid complex LangChain dependencies that might fail on a free Space.
|
| 211 |
-
|
| 212 |
-
# A simple function to format the prompt for the LLM
|
| 213 |
-
def format_prompt(context, question):
|
| 214 |
-
# This is a general instruction prompt for the LLM
|
| 215 |
-
system_prompt = (
|
| 216 |
-
"Wewe ni mtaalamu wa mazungumzo ya Kiswahili na Sheng. "
|
| 217 |
-
"Jibu maswali ya mtumiaji kwa kutumia muktadha uliotolewa kutoka kwa "
|
| 218 |
-
"maandishi ya 'A Nurse Toto' Episode 1. Ikiwa jibu halipatikani kwenye "
|
| 219 |
-
"muktadha, jibu kwa heshima kwamba huna habari hiyo, lakini bado "
|
| 220 |
-
"tumia lugha ya Kiswahili au Sheng."
|
| 221 |
-
)
|
| 222 |
-
return f"{system_prompt}\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
|
| 223 |
-
|
| 224 |
-
# A simple function to run the RAG process
|
| 225 |
-
def rag_qa(question):
|
| 226 |
-
# 1. Retrieve context
|
| 227 |
-
docs = retriever.get_relevant_documents(question)
|
| 228 |
-
context = "\n---\n".join([doc.page_content for doc in docs])
|
| 229 |
-
|
| 230 |
-
# 2. Format prompt
|
| 231 |
-
prompt = format_prompt(context, question)
|
| 232 |
-
|
| 233 |
-
# 3. Generate response
|
| 234 |
-
# Using the Hugging Face pipeline for text generation
|
| 235 |
-
pipe = pipeline(
|
| 236 |
-
"text-generation",
|
| 237 |
-
model=model,
|
| 238 |
-
tokenizer=tokenizer,
|
| 239 |
max_new_tokens=256,
|
| 240 |
do_sample=True,
|
| 241 |
temperature=0.7,
|
| 242 |
-
top_p=0.9
|
| 243 |
)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
if "Answer:" in output:
|
| 250 |
-
answer = output.split("Answer:", 1)[-1].strip()
|
| 251 |
-
else:
|
| 252 |
-
answer = output.split(prompt, 1)[-1].strip() # Fallback
|
| 253 |
-
|
| 254 |
-
return answer
|
| 255 |
-
|
| 256 |
-
rag_chain = rag_qa
|
| 257 |
-
print("RAG chain initialized.")
|
| 258 |
-
|
| 259 |
-
# --- Gradio Interface ---
|
| 260 |
-
|
| 261 |
-
def chat_function(message, history):
|
| 262 |
-
"""The main function for the Gradio chat interface."""
|
| 263 |
-
if rag_chain is None:
|
| 264 |
-
# Attempt to set up the chain on the first message if it failed before
|
| 265 |
-
setup_rag_chain()
|
| 266 |
-
if rag_chain is None:
|
| 267 |
-
return "Samahani, mfumo wa lugha haukuweza kupakiwa. Tafadhali jaribu tena baadaye."
|
| 268 |
-
|
| 269 |
-
# The history is not used for RAG, as it's a simple QA chain.
|
| 270 |
-
# For a conversational model, history would be included in the prompt.
|
| 271 |
-
response = rag_chain(message)
|
| 272 |
return response
|
| 273 |
|
| 274 |
-
# Initialize the
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
# Define the Gradio interface
|
| 278 |
-
if rag_chain is not None:
|
| 279 |
gr.ChatInterface(
|
| 280 |
-
fn=
|
| 281 |
-
title="
|
| 282 |
-
description=
|
| 283 |
-
"Uliza maswali kuhusu maandishi ya 'A Nurse Toto' Episode 1 kwa Kiswahili au Sheng. "
|
| 284 |
-
"Mfumo huu unatumia **Retrieval-Augmented Generation (RAG)** na model ya Kiswahili "
|
| 285 |
-
f"kutoka Hugging Face ({MODEL_NAME}) kujibu maswali yako."
|
| 286 |
-
),
|
| 287 |
examples=[
|
| 288 |
-
["
|
| 289 |
-
["
|
| 290 |
-
["
|
| 291 |
-
["Nani alikuwa mroho kama magwanda ya mekanika?"],
|
| 292 |
["Mzee alitaka kufanya nini hospitalini?"],
|
| 293 |
]
|
| 294 |
).launch()
|
| 295 |
else:
|
|
|
|
| 296 |
gr.Interface(
|
| 297 |
-
fn=lambda x: "
|
| 298 |
inputs="text",
|
| 299 |
outputs="text",
|
| 300 |
title="Chatbot Initialization Failed"
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from langchain_text_splitters import CharacterTextSplitter
|
| 5 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 6 |
from langchain_community.vectorstores import FAISS
|
|
|
|
| 7 |
|
| 8 |
# --- Configuration ---
# Smallest available Swahili model (1B) so the app can run on a free CPU tier.
# NOTE(review): the previous "-litert" checkpoint is a LiteRT/TFLite export,
# which AutoModelForCausalLM cannot load; use the standard transformers
# checkpoint instead — confirm the exact repo id on the Hub.
MODEL_NAME = "CraneAILabs/swahili-gemma-1b"
# Sentence-embedding model used to index the transcript for retrieval.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
| 12 |
|
| 13 |
+
# --- Transcript Data ---
|
|
|
|
|
|
|
| 14 |
NURSE_TOTO_TRANSCRIPT = """
|
| 15 |
# A Nurse Toto - Episode 1: Mzee wa Kutahirii (Kiswahili Transcript)
|
|
|
|
| 16 |
**Series:** A Nurse Toto
|
| 17 |
**Episode:** 1 - Mzee wa Kutahirii
|
| 18 |
**Creator:** Eddie Butita
|
|
|
|
| 76 |
**Maryanne:** Mzee, unajua unasumbua wewe? Hebu keti hapo. Utalipa 500 ya registration, utaona daktari na 1,000, alafu 15k, hiyo ni ya circumcision.
|
| 77 |
**Casypool:** Silipi kitu, niko na insurance.
|
| 78 |
**Maryanne:** Ni sawa, uko na insurance. But sasa sijui kama insurance inakava wazee wa umri yako kutahiri. Utangoja hapo usikie kama watakubali.
|
| 79 |
+
**Casypool:** Sasa, kitu ya kutokutahiri, utaenda kutangazia insurance ati sijatahiri?
|
| 80 |
**Maryanne:** Mzee, lakini vitu zingine ni za kujisimamia. Hizi ni aibu gani za ati, "Oh, mzee wa 52 years, circumcision na NHIF." Surely. Surely.
|
| 81 |
|
| 82 |
---
|
|
|
|
| 148 |
**Sly:** Ndio maana ulikuwa unasema tungoje, sindio?
|
| 149 |
"""
|
| 150 |
|
| 151 |
+
# --- Global Variables ---
# Populated by setup_system(); all three stay None until initialization
# succeeds, which the chat handler uses to detect an unusable state.
tokenizer = None  # HF tokenizer for MODEL_NAME
model = None  # causal LM used for generation
vector_db = None  # FAISS index over the transcript chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
def setup_system():
    """Initialize the LLM, tokenizer, and FAISS vector store for RAG.

    Returns:
        bool: True when everything loaded; False on any failure. On failure
        ALL globals are reset to None so generate_response() can reliably
        detect the unusable state (previously a partial failure — model
        loaded but indexing failed — left `model` set and crashed the chat
        path later with `vector_db is None`).
    """
    global tokenizer, model, vector_db

    try:
        print(f"Loading tokenizer and model: {MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Load explicitly on CPU with float32 — the safest combination for a
        # free, CPU-only hosting tier (avoids half-precision CPU issues and
        # device auto-detection surprises).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
        model.eval()  # inference only; disables dropout etc.
        print("Model loaded successfully.")

        # Build the RAG index: split the transcript into overlapping chunks,
        # embed them, and store them in FAISS for similarity search.
        text_splitter = CharacterTextSplitter(
            separator="\n\n", chunk_size=1000, chunk_overlap=200
        )
        texts = text_splitter.create_documents([NURSE_TOTO_TRANSCRIPT])

        print("Creating embeddings and vector store...")
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        vector_db = FAISS.from_documents(texts, embeddings)
        print("System setup complete.")
        return True
    except Exception as e:
        # Reset everything: a partial success must not leave the chat
        # handler half-initialized.
        tokenizer = None
        model = None
        vector_db = None
        print(f"FATAL ERROR: System setup failed (model loading or indexing). Error: {e}")
        return False
|
| 187 |
+
|
| 188 |
+
def generate_response(message, history):
    """Gradio chat handler: RAG over the transcript plus general Swahili chat.

    Args:
        message: The user's latest message (str).
        history: Gradio chat history — currently unused; each turn is
            answered independently (no multi-turn context in the prompt).

    Returns:
        str: The model's answer, or a Swahili error message when the
        system failed to initialize.
    """
    # Guard against failed OR partial initialization — check every global
    # this function dereferences, not just `model`.
    if model is None or tokenizer is None or vector_db is None:
        return "Samahani, mfumo wa lugha haukuweza kupakiwa kwa sababu ya matatizo ya kumbukumbu (memory issues). Tafadhali jaribu tena baadaye au tumia mfumo mdogo zaidi."

    # 1. Retrieve the most relevant transcript chunks for the question.
    docs = vector_db.similarity_search(message, k=2)
    context = "\n".join([doc.page_content for doc in docs])

    # 2. Construct the prompt: grounded context plus permission for the
    # model to answer general questions outside the transcript.
    system_prompt = (
        "Wewe ni msaidizi wa AI unayezungumza Kiswahili na Sheng. "
        "Unaweza kufanya mazungumzo ya kawaida au kujibu maswali kuhusu 'Nurse Toto' "
        "kwa kutumia muktadha uliotolewa hapa chini. "
        "Ikiwa swali halihusiani na Nurse Toto, jibu kwa kutumia maarifa yako ya jumla."
    )

    full_prompt = f"{system_prompt}\n\nMuktadha wa Nurse Toto:\n{context}\n\nUser: {message}\nAssistant:"

    # 3. Generate. truncation=True caps the prompt at the model's context
    # window so a long retrieved context cannot crash generation.
    inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text echoes the prompt; keep only what follows the final
    # "Assistant:" marker.
    response = full_output.split("Assistant:")[-1].strip()
    return response
|
| 226 |
|
| 227 |
+
# Initialize the system. If it fails, the model will be None and the chat function will return an error.
|
| 228 |
+
if setup_system():
|
| 229 |
+
# Launch Gradio only if setup was successful
|
|
|
|
|
|
|
| 230 |
gr.ChatInterface(
|
| 231 |
+
fn=generate_response,
|
| 232 |
+
title="Lightweight Swahili/Sheng Chatbot (Nurse Toto RAG)",
|
| 233 |
+
description="Chat na AI kwa Kiswahili au Sheng! Inajua mambo ya Nurse Toto na mambo mengine ya kawaida.",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
examples=[
|
| 235 |
+
["Habari yako? Unaweza kunisaidia nini leo?"],
|
| 236 |
+
["Nieleze kuhusu Casypool kwenye Nurse Toto."],
|
| 237 |
+
["Sheng ya 'How are you' ni gani?"],
|
|
|
|
| 238 |
["Mzee alitaka kufanya nini hospitalini?"],
|
| 239 |
]
|
| 240 |
).launch()
|
| 241 |
else:
|
| 242 |
+
# If setup fails, launch a simple interface with an error message
|
| 243 |
gr.Interface(
|
| 244 |
+
fn=lambda x: "Samahani, mfumo wa lugha haukuweza kupakiwa kwa sababu ya matatizo ya kumbukumbu (memory issues). Tafadhali jaribu tena baadaye au tumia mfumo mdogo zaidi.",
|
| 245 |
inputs="text",
|
| 246 |
outputs="text",
|
| 247 |
title="Chatbot Initialization Failed"
|