Chia Woon Yap committed on
Commit
aee7757
·
verified ·
1 Parent(s): 2a7cdbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -36
app.py CHANGED
@@ -16,13 +16,29 @@ import time
16
  import groq
17
  import uuid # For generating unique filenames
18
 
19
- # Updated imports to address LangChain deprecation warnings:
20
- from langchain_groq import ChatGroq
21
- from langchain.schema import HumanMessage
22
- from langchain.text_splitter import RecursiveCharacterTextSplitter
23
- from langchain_community.vectorstores import Chroma
24
- from langchain_community.embeddings import HuggingFaceEmbeddings
25
- from langchain.docstore.document import Document
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Importing chardet (make sure to add chardet to your requirements.txt)
28
  import chardet
@@ -41,7 +57,22 @@ import uvicorn
41
  from typing import Optional
42
  import io
43
  import soundfile as sf
44
- import librosa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Enhanced Whisper model for speech-to-text with better configuration
47
  try:
@@ -53,24 +84,44 @@ try:
53
  stride_length_s=5,
54
  batch_size=8
55
  )
 
56
  except Exception as e:
57
  print(f"Warning: Could not load enhanced Whisper model: {e}")
58
- # Fallback to basic model
59
- transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
 
 
 
 
 
60
 
61
  # Set API Key (Ensure it's stored securely in an environment variable)
62
  groq.api_key = os.getenv("GROQ_API_KEY")
63
 
64
- # Initialize Chat Model
65
- chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq.api_key)
 
 
 
 
 
 
 
 
 
66
 
67
- # Initialize Embeddings and chromaDB
68
- os.makedirs("chroma_db", exist_ok=True)
69
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
70
- vectorstore = Chroma(
71
- embedding_function=embedding_model,
72
- persist_directory="chroma_db"
73
- )
 
 
 
 
 
74
 
75
  # Short-term memory for the LLM
76
  chat_memory = []
@@ -140,6 +191,9 @@ def clean_response(response):
140
 
141
  # Function to generate quiz based on content
142
  def generate_quiz(content):
 
 
 
143
  prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
144
  response = chat_model([HumanMessage(content=prompt)])
145
  cleaned_response = clean_response(response.content)
@@ -147,6 +201,9 @@ def generate_quiz(content):
147
 
148
  # Function to retrieve relevant documents from vectorstore based on user query
149
  def retrieve_documents(query):
 
 
 
150
  results = vectorstore.similarity_search(query, k=3)
151
  return [doc.page_content for doc in results]
152
 
@@ -173,6 +230,12 @@ def convert_to_tuple_format(chat_history):
173
  # Function to handle chatbot interactions with short-term memory
174
  def chat_with_groq(user_input, chat_history):
175
  try:
 
 
 
 
 
 
176
  # Convert message format to tuple format for processing
177
  tuple_history = convert_to_tuple_format(chat_history)
178
 
@@ -284,6 +347,10 @@ def process_document(file):
284
  encoding = detect_encoding(file.name)
285
  with open(file.name, "r", encoding=encoding, errors="replace") as f:
286
  content = f.read()
 
 
 
 
287
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
288
  documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
289
  vectorstore.add_documents(documents)
@@ -312,16 +379,16 @@ def preprocess_audio(audio_data, sample_rate):
312
  audio_data = audio_data / max_val
313
 
314
  # Resample to 16kHz if needed (Whisper works best with 16kHz)
315
- if sample_rate != AUDIO_SAMPLE_RATE:
316
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
317
  sample_rate = AUDIO_SAMPLE_RATE
318
 
319
- # Apply noise reduction (simple high-pass filter)
320
- import scipy.signal as sp
321
- nyquist = sample_rate / 2
322
- cutoff = 80 # High-pass filter cutoff frequency in Hz
323
- b, a = sp.butter(2, cutoff/nyquist, btype='high')
324
- audio_data = sp.filtfilt(b, a, audio_data)
325
 
326
  return audio_data, sample_rate
327
 
@@ -338,6 +405,9 @@ def transcribe_audio(audio):
338
  if audio is None:
339
  return "No audio input detected."
340
 
 
 
 
341
  sample_rate, audio_data = audio
342
 
343
  # Preprocess audio
@@ -427,6 +497,12 @@ async def api_chat(message: str = Form(...)):
427
  API endpoint for chat interactions
428
  """
429
  try:
 
 
 
 
 
 
430
  # Simple chat response without memory for API
431
  prompt = f"You are a helpful AI tutor. Answer the following question accurately and concisely: {message}"
432
  response = chat_model([HumanMessage(content=prompt)])
@@ -490,7 +566,13 @@ async def api_process_document(file: UploadFile = File(...)):
490
  @app.get("/api/health")
491
  async def health_check():
492
  """Health check endpoint"""
493
- return {"status": "healthy", "timestamp": time.time()}
 
 
 
 
 
 
494
 
495
  # Clear chat history function
496
  def clear_chat_history():
@@ -598,15 +680,21 @@ def tutor_ai_chatbot():
598
  gr.Video("We_not_me_video.mp4", label="Introduction Video")
599
 
600
  # Launch the application
601
- gradio_app.launch(share=False)
602
 
603
- # Run both FastAPI and Gradio
604
  if __name__ == "__main__":
605
- import threading
606
-
607
- # Start Gradio in a separate thread
608
- gradio_thread = threading.Thread(target=tutor_ai_chatbot, daemon=True)
609
- gradio_thread.start()
610
 
611
- # Start FastAPI
612
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
16
  import groq
17
  import uuid # For generating unique filenames
18
 
19
+ # Updated imports for LangChain compatibility
20
+ try:
21
+ # For newer versions of LangChain
22
+ from langchain_groq import ChatGroq
23
+ from langchain_core.messages import HumanMessage
24
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
25
+ from langchain_community.vectorstores import Chroma
26
+ from langchain_community.embeddings import HuggingFaceEmbeddings
27
+ from langchain_core.documents import Document
28
+ except ImportError:
29
+ # Fallback for older versions
30
+ try:
31
+ from langchain_groq import ChatGroq
32
+ from langchain.schema import HumanMessage
33
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
34
+ from langchain_community.vectorstores import Chroma
35
+ from langchain_community.embeddings import HuggingFaceEmbeddings
36
+ from langchain.docstore.document import Document
37
+ except ImportError as e:
38
+ print(f"Import error: {e}")
39
+ # Minimal imports to keep the app running
40
+ import warnings
41
+ warnings.warn("Some LangChain components not available")
42
 
43
  # Importing chardet (make sure to add chardet to your requirements.txt)
44
  import chardet
 
57
  from typing import Optional
58
  import io
59
  import soundfile as sf
60
+
61
+ # Try to import librosa for audio processing, but make it optional
62
+ try:
63
+ import librosa
64
+ LIBROSA_AVAILABLE = True
65
+ except ImportError:
66
+ LIBROSA_AVAILABLE = False
67
+ print("Warning: librosa not available. Audio preprocessing will be limited.")
68
+
69
+ # Try to import scipy for audio filtering, but make it optional
70
+ try:
71
+ import scipy.signal as sp
72
+ SCIPY_AVAILABLE = True
73
+ except ImportError:
74
+ SCIPY_AVAILABLE = False
75
+ print("Warning: scipy not available. Audio filtering will be limited.")
76
 
77
  # Enhanced Whisper model for speech-to-text with better configuration
78
  try:
 
84
  stride_length_s=5,
85
  batch_size=8
86
  )
87
+ print("Loaded Whisper-small model successfully")
88
  except Exception as e:
89
  print(f"Warning: Could not load enhanced Whisper model: {e}")
90
+ try:
91
+ # Fallback to basic model
92
+ transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
93
+ print("Loaded Whisper-base model as fallback")
94
+ except Exception as e2:
95
+ print(f"Error loading any Whisper model: {e2}")
96
+ transcriber = None
97
 
98
  # Set API Key (Ensure it's stored securely in an environment variable)
99
  groq.api_key = os.getenv("GROQ_API_KEY")
100
 
101
+ # Initialize Chat Model with error handling
102
+ try:
103
+ if groq.api_key:
104
+ chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq.api_key)
105
+ CHAT_MODEL_AVAILABLE = True
106
+ else:
107
+ print("GROQ_API_KEY not found in environment variables")
108
+ CHAT_MODEL_AVAILABLE = False
109
+ except Exception as e:
110
+ print(f"Error initializing chat model: {e}")
111
+ CHAT_MODEL_AVAILABLE = False
112
 
113
+ # Initialize Embeddings and chromaDB with error handling
114
+ try:
115
+ os.makedirs("chroma_db", exist_ok=True)
116
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
117
+ vectorstore = Chroma(
118
+ embedding_function=embedding_model,
119
+ persist_directory="chroma_db"
120
+ )
121
+ VECTORSTORE_AVAILABLE = True
122
+ except Exception as e:
123
+ print(f"Error initializing vectorstore: {e}")
124
+ VECTORSTORE_AVAILABLE = False
125
 
126
  # Short-term memory for the LLM
127
  chat_memory = []
 
191
 
192
  # Function to generate quiz based on content
193
  def generate_quiz(content):
194
+ if not CHAT_MODEL_AVAILABLE:
195
+ return "Chat model not available. Please check GROQ_API_KEY configuration."
196
+
197
  prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
198
  response = chat_model([HumanMessage(content=prompt)])
199
  cleaned_response = clean_response(response.content)
 
201
 
202
  # Function to retrieve relevant documents from vectorstore based on user query
203
  def retrieve_documents(query):
204
+ if not VECTORSTORE_AVAILABLE:
205
+ return ["Vector store not available."]
206
+
207
  results = vectorstore.similarity_search(query, k=3)
208
  return [doc.page_content for doc in results]
209
 
 
230
  # Function to handle chatbot interactions with short-term memory
231
  def chat_with_groq(user_input, chat_history):
232
  try:
233
+ if not CHAT_MODEL_AVAILABLE:
234
+ error_msg = "Chat model not available. Please check configuration."
235
+ chat_history.append({"role": "user", "content": user_input})
236
+ chat_history.append({"role": "assistant", "content": error_msg})
237
+ return chat_history, "", None
238
+
239
  # Convert message format to tuple format for processing
240
  tuple_history = convert_to_tuple_format(chat_history)
241
 
 
347
  encoding = detect_encoding(file.name)
348
  with open(file.name, "r", encoding=encoding, errors="replace") as f:
349
  content = f.read()
350
+
351
+ if not VECTORSTORE_AVAILABLE:
352
+ return f"Document processed but vector store not available. Content preview: {content[:500]}..."
353
+
354
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
355
  documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
356
  vectorstore.add_documents(documents)
 
379
  audio_data = audio_data / max_val
380
 
381
  # Resample to 16kHz if needed (Whisper works best with 16kHz)
382
+ if LIBROSA_AVAILABLE and sample_rate != AUDIO_SAMPLE_RATE:
383
  audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
384
  sample_rate = AUDIO_SAMPLE_RATE
385
 
386
+ # Apply noise reduction (simple high-pass filter) if scipy is available
387
+ if SCIPY_AVAILABLE:
388
+ nyquist = sample_rate / 2
389
+ cutoff = 80 # High-pass filter cutoff frequency in Hz
390
+ b, a = sp.butter(2, cutoff/nyquist, btype='high')
391
+ audio_data = sp.filtfilt(b, a, audio_data)
392
 
393
  return audio_data, sample_rate
394
 
 
405
  if audio is None:
406
  return "No audio input detected."
407
 
408
+ if transcriber is None:
409
+ return "Speech-to-text service not available. Please check model configuration."
410
+
411
  sample_rate, audio_data = audio
412
 
413
  # Preprocess audio
 
497
  API endpoint for chat interactions
498
  """
499
  try:
500
+ if not CHAT_MODEL_AVAILABLE:
501
+ return JSONResponse({
502
+ "success": False,
503
+ "error": "Chat model not available"
504
+ }, status_code=503)
505
+
506
  # Simple chat response without memory for API
507
  prompt = f"You are a helpful AI tutor. Answer the following question accurately and concisely: {message}"
508
  response = chat_model([HumanMessage(content=prompt)])
 
566
  @app.get("/api/health")
567
  async def health_check():
568
  """Health check endpoint"""
569
+ return {
570
+ "status": "healthy",
571
+ "timestamp": time.time(),
572
+ "chat_model_available": CHAT_MODEL_AVAILABLE,
573
+ "vectorstore_available": VECTORSTORE_AVAILABLE,
574
+ "stt_available": transcriber is not None
575
+ }
576
 
577
  # Clear chat history function
578
  def clear_chat_history():
 
680
  gr.Video("We_not_me_video.mp4", label="Introduction Video")
681
 
682
  # Launch the application
683
+ gradio_app.launch(share=False, server_name="0.0.0.0", server_port=7860)
684
 
685
+ # Run the application based on command line arguments
686
  if __name__ == "__main__":
687
+ import sys
 
 
 
 
688
 
689
+ if len(sys.argv) > 1 and sys.argv[1] == "api":
690
+ # Run only FastAPI
691
+ print("Starting FastAPI server on port 8000...")
692
+ uvicorn.run(app, host="0.0.0.0", port=8000)
693
+ elif len(sys.argv) > 1 and sys.argv[1] == "gradio":
694
+ # Run only Gradio
695
+ print("Starting Gradio interface on port 7860...")
696
+ tutor_ai_chatbot()
697
+ else:
698
+ # Run both (Gradio in main thread for stability)
699
+ print("Starting Gradio interface (main thread) and FastAPI server (background)...")
700
+ tutor_ai_chatbot()