Chia Woon Yap committed on
Commit 4e17344 · verified · 1 Parent(s): 90d9abe

Create app.py

Files changed (1)
  1. app.py +516 -0
app.py ADDED
@@ -0,0 +1,516 @@
+ # -*- coding: utf-8 -*-
+ """app
+ Automatically generated by Colab.
+ Original file is located at
+ https://colab.research.google.com/drive/1pwwcBb5Zlw1DA3u5K8W8mjrwBTBWXc1L
+ """
+
+ import gradio as gr
+ import numpy as np
+ from transformers import pipeline
+ import os
+ import time
+ import uuid
+
+ # LangChain imports
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_core.documents import Document
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain_groq import ChatGroq
+
+ # Other imports
+ import chardet
+ import fitz  # PyMuPDF for PDFs
+ import docx  # python-docx for Word files
+ import gtts  # Google Text-to-Speech library
+ from pptx import Presentation  # python-pptx for PowerPoint files
+ import re
+
+ import torch
+ import torchaudio
+
+ # Set API key
+ groq_api_key = os.getenv("GROQ_API_KEY")
+
+ # Initialize chat model
+ chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq_api_key)
+
+ # Initialize embeddings and ChromaDB
+ os.makedirs("chroma_db", exist_ok=True)
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ vectorstore = Chroma(
+     embedding_function=embedding_model,
+     persist_directory="chroma_db"
+ )
+
+ # Short-term memory for the LLM
+ chat_memory = []
+
+ # Prompt for quiz generation
+ quiz_prompt = """
+ You are an AI assistant specialized in education and assessment creation. Given an uploaded document or text, generate a quiz with a mix of multiple-choice questions (MCQs) and fill-in-the-blank questions. The quiz should be directly based on the key concepts, facts, and details from the provided material.
+ Generate 20 questions.
+ Remove all unnecessary formatting generated by the LLM, including <think> tags, asterisks, markdown formatting, and any bold or italic text, as well as **, ###, ##, and # tags.
+ For each question:
+ - Provide 4 answer choices (for MCQs), with only one correct answer.
+ - Ensure fill-in-the-blank questions focus on key terms, phrases, or concepts from the document.
+ - Include an answer key for all questions.
+ - Ensure questions vary in difficulty and encourage comprehension rather than memorization.
+ - Additionally, implement an instant feedback mechanism:
+   - When a user selects an answer, indicate whether it is correct or incorrect.
+   - If incorrect, provide a brief explanation from the document to guide learning.
+   - Ensure responses are concise and educational to enhance understanding.
+ Output Example:
+ 1. Fill in the blank: The LLM Agent framework has a central decision-making unit called the _______________________.
+ Answer: Agent Core
+ Feedback: The Agent Core is the central component of the LLM Agent framework, responsible for managing goals, tool instructions, planning modules, memory integration, and agent persona.
+ 2. What is the main limitation of LLM-based applications?
+ a) Limited token capacity
+ b) Lack of domain expertise
+ c) Prone to hallucination
+ d) All of the above
+ Answer: d) All of the above
+ Feedback: LLM-based applications have several limitations, including limited token capacity, lack of domain expertise, and being prone to hallucination, among others.
+ 3. Given the following info, what is the value of P(Jam|Rain)?
+ P(no Rain) = 0.8;
+ P(no Jam) = 0.2;
+ P(Rain|Jam) = 0.1
+ a) 0.016
+ b) 0.025
+ c) 0.1
+ d) 0.4
+ Answer: d) 0.4
+ Feedback: This question tests understanding of Bayes' Theorem by requiring the calculation of conditional probability using the given values.
+ """
+
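+ # Sanity check (a hedged aside, not part of the app logic): the example answer key
+ # above can be verified with Bayes' theorem, using P(Jam) = 1 - P(no Jam) = 0.8 and
+ # P(Rain) = 1 - P(no Rain) = 0.2:
+ #   P(Jam|Rain) = P(Rain|Jam) * P(Jam) / P(Rain) = 0.1 * 0.8 / 0.2 = 0.4,
+ # which matches answer d) 0.4.
+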
+ # Enhanced Whisper transcriber with chunked processing
+ class EnhancedWhisperTranscriber:
+     def __init__(self, model_name=None):
+         # Auto-select the optimal model for the available hardware
+         if model_name is None:
+             model_name = self.get_optimal_model()
+
+         self.device = 0 if torch.cuda.is_available() else "cpu"
+         self.model_name = model_name
+
+         print(f"Initializing Whisper model: {model_name} on {self.device}")
+
+         self.pipe = pipeline(
+             task="automatic-speech-recognition",
+             model=model_name,
+             chunk_length_s=30,  # Process in 30-second chunks
+             stride_length_s=5,  # 5-second overlap between chunks
+             device=self.device,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         )
+
+     def get_optimal_model(self):
+         """Automatically select the best model for the available hardware."""
+         if torch.cuda.is_available():
+             gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+             if gpu_memory > 8:  # 8 GB+ VRAM
+                 return "openai/whisper-small.en"
+             else:  # Limited VRAM
+                 return "openai/whisper-base.en"
+         else:  # CPU only
+             return "openai/whisper-base.en"  # Balanced choice for CPU
+
+     def transcribe_numpy(self, sr, y, return_timestamps=False):
+         """Transcribe a NumPy audio array with chunked processing."""
+         try:
+             # Audio preprocessing
+             if y.ndim > 1:
+                 y = y.mean(axis=1)  # Convert to mono
+
+             y = y.astype(np.float32)
+
+             # Normalize audio
+             max_val = np.max(np.abs(y))
+             if max_val > 0:
+                 y = y / max_val
+
+             # Remove silence (simple threshold-based)
+             silence_threshold = 0.01
+             non_silent_indices = np.where(np.abs(y) > silence_threshold)[0]
+
+             if len(non_silent_indices) == 0:
+                 return "No speech detected. Please speak louder or check your microphone."
+
+             # Trim silence from the beginning and end
+             start_idx = non_silent_indices[0]
+             end_idx = non_silent_indices[-1]
+             y_trimmed = y[start_idx:end_idx + 1]
+
+             # Check whether the audio is too short
+             if len(y_trimmed) / sr < 0.5:  # Less than 0.5 seconds
+                 return "Audio too short. Please speak for at least 1-2 seconds."
+
+             # Create the audio dict expected by the pipeline
+             inputs = {"array": y_trimmed, "sampling_rate": sr}
+
+             # Transcribe with chunked processing
+             result = self.pipe(
+                 inputs,
+                 batch_size=4,  # Batch size for chunked processing
+                 generate_kwargs={"task": "transcribe"},
+                 return_timestamps=return_timestamps
+             )
+
+             text = result["text"].strip()
+
+             if not text:
+                 return "No clear speech detected. Try speaking more clearly or in a quieter environment."
+
+             return text
+
+         except Exception as e:
+             error_msg = f"Transcription error: {str(e)}"
+             print(error_msg)
+             return "Sorry, I couldn't process the audio. Please try again or type your message instead."
+
+ # Initialize the enhanced transcriber
+ transcriber = EnhancedWhisperTranscriber()
+
+ def get_transcription_status(audio):
+     """Provide status feedback for transcription."""
+     if audio is None:
+         return "Ready to record audio"
+
+     sr, y = audio
+     duration = len(y) / sr if sr > 0 else 0
+
+     if duration < 0.5:
+         return "Audio too short - please record at least 1 second"
+     elif duration > 60 and not torch.cuda.is_available():
+         return "Long audio detected on CPU - this may take a while..."
+     else:
+         device = "GPU" if torch.cuda.is_available() else "CPU"
+         return f"Processing {duration:.1f}s audio on {device}..."
+
+ def transcribe_audio(audio):
+     """Main transcription function with progress feedback."""
+     if audio is None:
+         return "Please record audio first"
+
+     # Show device info for debugging
+     device_type = "GPU" if torch.cuda.is_available() else "CPU"
+     print(f"Transcribing on {device_type} using {transcriber.model_name}")
+
+     sr, y = audio
+
+     # Warn CPU users that long audio may be slow
+     audio_duration = len(y) / sr if sr > 0 else 0
+     if not torch.cuda.is_available() and audio_duration > 30:  # Longer than 30 seconds on CPU
+         print("Warning: Long audio on CPU - transcription may take a while...")
+
+     # Use the enhanced transcriber
+     result = transcriber.transcribe_numpy(sr, y)
+
+     # Log the transcription result for debugging
+     print(f"Transcription result: {result[:100]}...")
+
+     return result
+
+ # Function to clean the AI response by removing unwanted formatting
+ def clean_response(response):
+     """Removes <think> tags, asterisks, and markdown formatting."""
+     cleaned_text = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL)
+     cleaned_text = re.sub(r"(\*\*|\*|\[|\])", "", cleaned_text)
+     cleaned_text = re.sub(r"^##+\s*", "", cleaned_text, flags=re.MULTILINE)
+     cleaned_text = re.sub(r"\\", "", cleaned_text)
+     cleaned_text = re.sub(r"---", "", cleaned_text)
+     return cleaned_text.strip()
+
+ # Function to generate a quiz based on content
+ def generate_quiz(content):
+     prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
+     response = chat_model.invoke([HumanMessage(content=prompt)])
+     cleaned_response = clean_response(response.content)
+     return cleaned_response
+
+ # Function to retrieve relevant documents from the vectorstore based on the user query
+ def retrieve_documents(query):
+     results = vectorstore.similarity_search(query, k=3)
+     return [doc.page_content for doc in results]
+
+ # Function to convert tuple format to message format
+ def convert_to_message_format(chat_history):
+     message_format = []
+     for user_msg, bot_msg in chat_history:
+         message_format.append({"role": "user", "content": user_msg})
+         message_format.append({"role": "assistant", "content": bot_msg})
+     return message_format
+
+ # Function to convert message format to tuple format for processing
+ def convert_to_tuple_format(chat_history):
+     tuple_format = []
+     for i in range(0, len(chat_history), 2):
+         if i + 1 < len(chat_history):
+             user_msg = chat_history[i]["content"]
+             bot_msg = chat_history[i + 1]["content"]
+             tuple_format.append((user_msg, bot_msg))
+     return tuple_format
+
+ # Function to handle chatbot interactions with short-term memory
+ def chat_with_groq(user_input, chat_history):
+     try:
+         # Convert message format to tuple format for processing
+         tuple_history = convert_to_tuple_format(chat_history)
+
+         # Retrieve relevant documents for additional context
+         relevant_docs = retrieve_documents(user_input)
+         context = "\n".join(relevant_docs) if relevant_docs else "No relevant documents found."
+
+         # Construct the prompt with conversation history
+         system_prompt = "You are a helpful AI assistant. Answer questions accurately and concisely."
+         conversation_history = "\n".join(chat_memory[-10:])
+         prompt = f"{system_prompt}\n\nConversation History:\n{conversation_history}\n\nUser Input: {user_input}\n\nContext:\n{context}"
+
+         # Call the chat model
+         response = chat_model.invoke([HumanMessage(content=prompt)])
+
+         # Clean the response to remove any unwanted formatting
+         cleaned_response_text = clean_response(response.content)
+
+         # Append to the conversation memory
+         chat_memory.append(f"User: {user_input}")
+         chat_memory.append(f"AI: {cleaned_response_text}")
+
+         # Update the chat history
+         chat_history.append({"role": "user", "content": user_input})
+         chat_history.append({"role": "assistant", "content": cleaned_response_text})
+
+         # Convert the response to speech
+         audio_file = speech_playback(cleaned_response_text)
+
+         return chat_history, "", audio_file
+     except Exception as e:
+         error_msg = f"Error: {str(e)}"
+         chat_history.append({"role": "user", "content": user_input})
+         chat_history.append({"role": "assistant", "content": error_msg})
+         return chat_history, "", None
+
+ # Function to play the response as speech using gTTS
+ def speech_playback(text):
+     try:
+         # Generate a unique filename for each audio file
+         unique_id = str(uuid.uuid4())
+         audio_file = f"output_audio_{unique_id}.mp3"
+
+         # Convert text to speech
+         tts = gtts.gTTS(text, lang='en')
+         tts.save(audio_file)
+
+         # Return the path to the audio file
+         return audio_file
+     except Exception as e:
+         print(f"Error in speech_playback: {e}")
+         return None
+
+ # Function to detect encoding safely
+ def detect_encoding(file_path):
+     try:
+         with open(file_path, "rb") as f:
+             raw_data = f.read(4096)
+         detected = chardet.detect(raw_data)
+         encoding = detected["encoding"]
+         return encoding if encoding else "utf-8"
+     except Exception:
+         return "utf-8"
+
+ # Function to extract text from a PDF
+ def extract_text_from_pdf(pdf_path):
+     try:
+         doc = fitz.open(pdf_path)
+         text = "\n".join([page.get_text("text") for page in doc])
+         return text if text.strip() else "No extractable text found."
+     except Exception as e:
+         return f"Error extracting text from PDF: {str(e)}"
+
+ # Function to extract text from Word files (.docx)
+ def extract_text_from_docx(docx_path):
+     try:
+         doc = docx.Document(docx_path)
+         text = "\n".join([para.text for para in doc.paragraphs])
+         return text if text.strip() else "No extractable text found."
+     except Exception as e:
+         return f"Error extracting text from Word document: {str(e)}"
+
+ # Function to extract text from PowerPoint files (.pptx)
+ def extract_text_from_pptx(pptx_path):
+     try:
+         presentation = Presentation(pptx_path)
+         text = ""
+         for slide in presentation.slides:
+             for shape in slide.shapes:
+                 if hasattr(shape, "text"):
+                     text += shape.text + "\n"
+         return text if text.strip() else "No extractable text found."
+     except Exception as e:
+         return f"Error extracting text from PowerPoint: {str(e)}"
+
+ # Function to process documents safely
+ def process_document(file):
+     try:
+         file_extension = os.path.splitext(file.name)[-1].lower()
+         if file_extension in [".png", ".jpg", ".jpeg"]:
+             return "Error: Images cannot be processed for text extraction."
+         if file_extension == ".pdf":
+             content = extract_text_from_pdf(file.name)
+         elif file_extension == ".docx":
+             content = extract_text_from_docx(file.name)
+         elif file_extension == ".pptx":
+             content = extract_text_from_pptx(file.name)
+         else:
+             encoding = detect_encoding(file.name)
+             with open(file.name, "r", encoding=encoding, errors="replace") as f:
+                 content = f.read()
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+         documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
+         vectorstore.add_documents(documents)
+
+         quiz = generate_quiz(content)
+         return f"Document processed successfully (File Type: {file_extension}). Quiz generated:\n{quiz}"
+     except Exception as e:
+         return f"Error processing document: {str(e)}"
+
+ # Clear chat history function
+ def clear_chat_history():
+     chat_memory.clear()
+     return [], None
+
+ def tutor_ai_chatbot():
+     """Main Gradio interface for the Tutor AI Chatbot."""
+     with gr.Blocks() as app:
+         gr.Markdown("# AI Tutor - We.(POC)")
+         gr.Markdown("An interactive Personal AI Tutor chatbot to help with your learning needs.")
+
+         # Chatbot tab
+         with gr.Tab("AI Chatbot"):
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     chatbot = gr.Chatbot(height=500, type="messages")
+
+                 with gr.Column(scale=1):
+                     audio_playback = gr.Audio(label="Audio Response", type="filepath")
+
+             # Input controls span the full width
+             with gr.Row():
+                 msg = gr.Textbox(
+                     label="Ask a question",
+                     placeholder="Type your question here...",
+                     container=False
+                 )
+                 submit = gr.Button("Send")
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")
+
+                     # Transcription status indicator
+                     transcription_status = gr.Textbox(
+                         label="Transcription Status",
+                         interactive=False,
+                         value="Record audio to see status here",
+                         max_lines=2
+                     )
+
+             # Voice recording tips - only in the AI Chatbot tab
+             with gr.Accordion("Voice Recording Tips", open=False):
+                 gr.Markdown("""
+                 **For better speech recognition accuracy:**
+                 - Speak clearly and at a moderate pace
+                 - Record in a quiet environment
+                 - Keep the microphone close to your mouth (10-15 cm)
+                 - Use a good-quality microphone if possible
+                 - Review the transcribed text before sending
+                 - If transcription is poor, try recording again or type manually
+
+                 **Performance Info:**
+                 - GPU: fast transcription (2-5 seconds)
+                 - CPU: slower but functional (10-30 seconds for longer audio)
+                 - Model: whisper-base.en (balanced for accuracy and speed)
+                 """)
+
+             # Clear chat history button
+             clear_btn = gr.Button("Clear Chat")
+
+             # Handle chat interaction
+             submit.click(
+                 chat_with_groq,
+                 inputs=[msg, chatbot],
+                 outputs=[chatbot, msg, audio_playback]
+             )
+
+             # Clear both the visible history and the LLM's short-term memory
+             clear_btn.click(
+                 clear_chat_history,
+                 inputs=None,
+                 outputs=[chatbot, audio_playback]
+             )
+
+             # Also allow the Enter key to submit
+             msg.submit(
+                 chat_with_groq,
+                 inputs=[msg, chatbot],
+                 outputs=[chatbot, msg, audio_playback]
+             )
+
+             # Examples of questions students might ask
+             with gr.Accordion("Example Questions", open=False):
+                 gr.Examples(
+                     examples=[
+                         "Can you explain the concept of RLHF in AI?",
+                         "What are transformers in AI?",
+                         "What is MoE in AI?",
+                         "What are gating networks in AI?",
+                         "I am making a cake, please generate a baking recipe."
+                     ],
+                     inputs=msg
+                 )
+
+             # Connect audio input to transcription with status updates
+             audio_input.change(
+                 fn=get_transcription_status,
+                 inputs=audio_input,
+                 outputs=transcription_status
+             ).then(
+                 fn=transcribe_audio,
+                 inputs=audio_input,
+                 outputs=msg
+             ).then(
+                 fn=lambda x: "Transcription completed!" if x and x != "Please record audio first" else "Ready for new recording",
+                 inputs=msg,
+                 outputs=transcription_status
+             )
+
+         # Upload Notes & Generate Quiz tab
+         with gr.Tab("Upload Notes & Generate Quiz"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     file_input = gr.File(label="Upload Lecture Notes (PDF, DOCX, PPTX)")
+                 with gr.Column(scale=3):
+                     quiz_output = gr.Textbox(label="Generated Quiz", lines=10)
+
+             # Connect file input to document processing
+             file_input.change(process_document, inputs=file_input, outputs=quiz_output)
+
+         # Introduction Video tab
+         with gr.Tab("Introduction Video"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Welcome to the Introduction Video")
+                     gr.Markdown("Music: 'China-O' by musician Xu Mengyuan (徐梦圆 YUAN)")
+                     # Use the local video file stored in this Space
+                     gr.Video("We_not_me_video.mp4", label="Introduction Video")
+
+     # Launch the application
+     app.launch(share=False)
+
+ # Launch the AI chatbot
+ if __name__ == "__main__":
+     tutor_ai_chatbot()