Stanley03 committed on
Commit
4666f34
·
verified ·
1 Parent(s): 89509ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -126
app.py CHANGED
@@ -1,23 +1,18 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  from langchain_text_splitters import CharacterTextSplitter
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from langchain_community.vectorstores import FAISS
7
- # from langchain.chains import RetrievalQA # Not used in this RAG implementation
8
 
9
  # --- Configuration ---
10
- MODEL_NAME = "Jacaranda/UlizaLlama3" # Best Swahili LLM, but may require a paid GPU Space
11
- # Alternative for free CPU Space: "CraneAILabs/swahili-gemma-1b-litert"
12
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
- TRANSCRIPT_FILE = "nurse_toto_episode_1_transcript.md"
14
 
15
- # --- Transcript Data (for RAG) ---
16
- # The full transcript is loaded here. In a real scenario, this would be loaded from a file.
17
- # For simplicity and deployment, we'll embed the content directly.
18
  NURSE_TOTO_TRANSCRIPT = """
19
  # A Nurse Toto - Episode 1: Mzee wa Kutahirii (Kiswahili Transcript)
20
-
21
  **Series:** A Nurse Toto
22
  **Episode:** 1 - Mzee wa Kutahirii
23
  **Creator:** Eddie Butita
@@ -81,7 +76,7 @@ NURSE_TOTO_TRANSCRIPT = """
81
  **Maryanne:** Mzee, unajua unasumbua wewe? Hebu keti hapo. Utalipa 500 ya registration, utaona daktari na 1,000, alafu 15k, hiyo ni ya circumcision.
82
  **Casypool:** Silipi kitu, niko na insurance.
83
  **Maryanne:** Ni sawa, uko na insurance. But sasa sijui kama insurance inakava wazee wa umri yako kutahiri. Utangoja hapo usikie kama watakubali.
84
- **Casypool:** Sasa, kitu ya kutokutahiri, utaenda kutangazia insurance ati sijatahiri?
85
  **Maryanne:** Mzee, lakini vitu zingine ni za kujisimamia. Hizi ni aibu gani za ati, "Oh, mzee wa 52 years, circumcision na NHIF." Surely. Surely.
86
 
87
  ---
@@ -153,148 +148,100 @@ NURSE_TOTO_TRANSCRIPT = """
153
  **Sly:** Ndio maana ulikuwa unasema tungoje, sindio?
154
  """
155
 
156
- # --- Model and RAG Setup ---
157
- # Global variables to hold the model and RAG chain
158
  tokenizer = None
159
  model = None
160
- rag_chain = None
161
-
162
- def setup_rag_chain():
163
- """Initializes the LLM, tokenizer, and RAG chain."""
164
- global tokenizer, model, rag_chain
165
 
166
- if rag_chain is not None:
167
- return
 
168
 
169
- # 1. Load the Swahili LLM (using a smaller model for deployment)
170
- # Note: For a free Hugging Face Space, a small model is necessary.
171
- # The UlizaLlama3 is 8B and will likely require a paid GPU.
172
- # We will use a placeholder for the code, but advise the user.
173
  try:
174
  print(f"Loading tokenizer and model: {MODEL_NAME}...")
175
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
176
- # Load in 4-bit for memory efficiency
 
 
177
  model = AutoModelForCausalLM.from_pretrained(
178
  MODEL_NAME,
179
- load_in_4bit=True,
180
- torch_dtype=torch.bfloat16,
181
- device_map="auto"
182
  )
 
183
  print("Model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
184
  except Exception as e:
185
- print(f"Error loading model {MODEL_NAME}. Falling back to a dummy model. Error: {e}")
186
- # Fallback for local testing or if the model is too large for the environment
187
- def dummy_llm(prompt):
188
- return "Samahani, mfumo wa lugha haupatikani. Hata hivyo, ninaweza kujibu maswali kuhusu 'Nurse Toto' kulingana na maandishi."
189
- rag_chain = dummy_llm
190
- return
191
-
192
- # 2. Create documents from the transcript
193
- text_splitter = CharacterTextSplitter(
194
- separator="\n\n",
195
- chunk_size=1000,
196
- chunk_overlap=200,
197
- length_function=len,
 
 
 
 
 
 
 
 
198
  )
199
- texts = text_splitter.create_documents([NURSE_TOTO_TRANSCRIPT])
200
-
201
- # 3. Create embeddings and vector store
202
- print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}...")
203
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
204
- print("Creating FAISS vector store...")
205
- db = FAISS.from_documents(texts, embeddings)
206
- retriever = db.as_retriever(search_kwargs={"k": 3})
207
-
208
- # 4. Setup the RAG chain
209
- # We'll use a simple pipeline for generation and integrate it with the retriever manually
210
- # to avoid complex LangChain dependencies that might fail on a free Space.
211
-
212
- # A simple function to format the prompt for the LLM
213
- def format_prompt(context, question):
214
- # This is a general instruction prompt for the LLM
215
- system_prompt = (
216
- "Wewe ni mtaalamu wa mazungumzo ya Kiswahili na Sheng. "
217
- "Jibu maswali ya mtumiaji kwa kutumia muktadha uliotolewa kutoka kwa "
218
- "maandishi ya 'A Nurse Toto' Episode 1. Ikiwa jibu halipatikani kwenye "
219
- "muktadha, jibu kwa heshima kwamba huna habari hiyo, lakini bado "
220
- "tumia lugha ya Kiswahili au Sheng."
221
- )
222
- return f"{system_prompt}\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:"
223
-
224
- # A simple function to run the RAG process
225
- def rag_qa(question):
226
- # 1. Retrieve context
227
- docs = retriever.get_relevant_documents(question)
228
- context = "\n---\n".join([doc.page_content for doc in docs])
229
-
230
- # 2. Format prompt
231
- prompt = format_prompt(context, question)
232
-
233
- # 3. Generate response
234
- # Using the Hugging Face pipeline for text generation
235
- pipe = pipeline(
236
- "text-generation",
237
- model=model,
238
- tokenizer=tokenizer,
239
  max_new_tokens=256,
240
  do_sample=True,
241
  temperature=0.7,
242
- top_p=0.9,
243
  )
244
-
245
- # The model will generate the prompt and the answer, so we need to clean the output
246
- output = pipe(prompt)[0]['generated_text']
247
-
248
- # Simple cleaning to extract only the answer part
249
- if "Answer:" in output:
250
- answer = output.split("Answer:", 1)[-1].strip()
251
- else:
252
- answer = output.split(prompt, 1)[-1].strip() # Fallback
253
-
254
- return answer
255
-
256
- rag_chain = rag_qa
257
- print("RAG chain initialized.")
258
-
259
- # --- Gradio Interface ---
260
-
261
- def chat_function(message, history):
262
- """The main function for the Gradio chat interface."""
263
- if rag_chain is None:
264
- # Attempt to set up the chain on the first message if it failed before
265
- setup_rag_chain()
266
- if rag_chain is None:
267
- return "Samahani, mfumo wa lugha haukuweza kupakiwa. Tafadhali jaribu tena baadaye."
268
-
269
- # The history is not used for RAG, as it's a simple QA chain.
270
- # For a conversational model, history would be included in the prompt.
271
- response = rag_chain(message)
272
  return response
273
 
274
- # Initialize the RAG chain on startup
275
- setup_rag_chain()
276
-
277
- # Define the Gradio interface
278
- if rag_chain is not None:
279
  gr.ChatInterface(
280
- fn=chat_function,
281
- title="Nurse Toto Kiswahili/Sheng Chatbot (RAG)",
282
- description=(
283
- "Uliza maswali kuhusu maandishi ya 'A Nurse Toto' Episode 1 kwa Kiswahili au Sheng. "
284
- "Mfumo huu unatumia **Retrieval-Augmented Generation (RAG)** na model ya Kiswahili "
285
- f"kutoka Hugging Face ({MODEL_NAME}) kujibu maswali yako."
286
- ),
287
  examples=[
288
- ["Casypool ana miaka mingapi?"],
289
- ["Wambo na Sly walisema nini kuhusu mgonjwa?"],
290
- ["Mzee alikula nini jana?"],
291
- ["Nani alikuwa mroho kama magwanda ya mekanika?"],
292
  ["Mzee alitaka kufanya nini hospitalini?"],
293
  ]
294
  ).launch()
295
  else:
 
296
  gr.Interface(
297
- fn=lambda x: "Model loading failed. Check logs for details.",
298
  inputs="text",
299
  outputs="text",
300
  title="Chatbot Initialization Failed"
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from langchain_text_splitters import CharacterTextSplitter
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from langchain_community.vectorstores import FAISS
 
7
 
8
  # --- Configuration ---
9
+ # Switching to the smallest available Swahili model (1B) for guaranteed free CPU hosting
10
+ MODEL_NAME = "CraneAILabs/swahili-gemma-1b-litert"
11
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 
12
 
13
+ # --- Transcript Data ---
 
 
14
  NURSE_TOTO_TRANSCRIPT = """
15
  # A Nurse Toto - Episode 1: Mzee wa Kutahirii (Kiswahili Transcript)
 
16
  **Series:** A Nurse Toto
17
  **Episode:** 1 - Mzee wa Kutahirii
18
  **Creator:** Eddie Butita
 
76
  **Maryanne:** Mzee, unajua unasumbua wewe? Hebu keti hapo. Utalipa 500 ya registration, utaona daktari na 1,000, alafu 15k, hiyo ni ya circumcision.
77
  **Casypool:** Silipi kitu, niko na insurance.
78
  **Maryanne:** Ni sawa, uko na insurance. But sasa sijui kama insurance inakava wazee wa umri yako kutahiri. Utangoja hapo usikie kama watakubali.
79
+ **Casypool:** Sasa, kitu ya kutokutahiri, utaenda kutangazia insurance ati sijatahiri?
80
  **Maryanne:** Mzee, lakini vitu zingine ni za kujisimamia. Hizi ni aibu gani za ati, "Oh, mzee wa 52 years, circumcision na NHIF." Surely. Surely.
81
 
82
  ---
 
148
  **Sly:** Ndio maana ulikuwa unasema tungoje, sindio?
149
  """
150
 
151
+ # --- Global Variables ---
 
152
  tokenizer = None
153
  model = None
154
+ vector_db = None
 
 
 
 
155
 
156
+ def setup_system():
157
+ """Initializes the LLM and the Vector Database for RAG."""
158
+ global tokenizer, model, vector_db
159
 
 
 
 
 
160
  try:
161
  print(f"Loading tokenizer and model: {MODEL_NAME}...")
162
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
163
+
164
+ # Load model explicitly on CPU with a memory-safe dtype
165
+ # We are using the smallest available model (1B) to maximize chances of success on the free tier.
166
  model = AutoModelForCausalLM.from_pretrained(
167
  MODEL_NAME,
168
+ torch_dtype=torch.float32, # Safer for CPU-only environments
169
+ device_map="cpu" # Explicitly set to CPU to avoid auto-detection issues
 
170
  )
171
+ model.eval()
172
  print("Model loaded successfully.")
173
+
174
+ # Setup Vector DB for RAG
175
+ text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=1000, chunk_overlap=200)
176
+ texts = text_splitter.create_documents([NURSE_TOTO_TRANSCRIPT])
177
+
178
+ print("Creating embeddings and vector store...")
179
+ embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
180
+ vector_db = FAISS.from_documents(texts, embeddings)
181
+ print("System setup complete.")
182
+ return True
183
  except Exception as e:
184
+ print(f"FATAL ERROR: Model loading failed. This is likely due to memory constraints. Error: {e}")
185
+ # If model loading fails, we cannot proceed with the chatbot.
186
+ return False
187
+
188
+ def generate_response(message, history):
189
+ """Main chat function supporting both general chat and RAG."""
190
+ # Check if the model is loaded. If not, return the error message.
191
+ if model is None:
192
+ return "Samahani, mfumo wa lugha haukuweza kupakiwa kwa sababu ya matatizo ya kumbukumbu (memory issues). Tafadhali jaribu tena baadaye au tumia mfumo mdogo zaidi."
193
+
194
+ # 1. Retrieve relevant context from the transcript
195
+ docs = vector_db.similarity_search(message, k=2)
196
+ context = "\n".join([doc.page_content for doc in docs])
197
+
198
+ # 2. Construct the prompt
199
+ # We provide the context but instruct the model it can also chat generally.
200
+ system_prompt = (
201
+ "Wewe ni msaidizi wa AI unayezungumza Kiswahili na Sheng. "
202
+ "Unaweza kufanya mazungumzo ya kawaida au kujibu maswali kuhusu 'Nurse Toto' "
203
+ "kwa kutumia muktadha uliotolewa hapa chini. "
204
+ "Ikiwa swali halihusiani na Nurse Toto, jibu kwa kutumia maarifa yako ya jumla."
205
  )
206
+
207
+ full_prompt = f"{system_prompt}\n\nMuktadha wa Nurse Toto:\n{context}\n\nUser: {message}\nAssistant:"
208
+
209
+ # 3. Generate
210
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
211
+
212
+ with torch.no_grad():
213
+ outputs = model.generate(
214
+ **inputs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  max_new_tokens=256,
216
  do_sample=True,
217
  temperature=0.7,
218
+ top_p=0.9
219
  )
220
+
221
+ full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
222
+
223
+ # Extract only the assistant's response
224
+ response = full_output.split("Assistant:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  return response
226
 
227
+ # Initialize the system. If it fails, the model will be None and the chat function will return an error.
228
+ if setup_system():
229
+ # Launch Gradio only if setup was successful
 
 
230
  gr.ChatInterface(
231
+ fn=generate_response,
232
+ title="Lightweight Swahili/Sheng Chatbot (Nurse Toto RAG)",
233
+ description="Chat na AI kwa Kiswahili au Sheng! Inajua mambo ya Nurse Toto na mambo mengine ya kawaida.",
 
 
 
 
234
  examples=[
235
+ ["Habari yako? Unaweza kunisaidia nini leo?"],
236
+ ["Nieleze kuhusu Casypool kwenye Nurse Toto."],
237
+ ["Sheng ya 'How are you' ni gani?"],
 
238
  ["Mzee alitaka kufanya nini hospitalini?"],
239
  ]
240
  ).launch()
241
  else:
242
+ # If setup fails, launch a simple interface with an error message
243
  gr.Interface(
244
+ fn=lambda x: "Samahani, mfumo wa lugha haukuweza kupakiwa kwa sababu ya matatizo ya kumbukumbu (memory issues). Tafadhali jaribu tena baadaye au tumia mfumo mdogo zaidi.",
245
  inputs="text",
246
  outputs="text",
247
  title="Chatbot Initialization Failed"