Spaces:
Sleeping
Sleeping
Chia Woon Yap
committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,13 +16,29 @@ import time
|
|
| 16 |
import groq
|
| 17 |
import uuid # For generating unique filenames
|
| 18 |
|
| 19 |
-
# Updated imports
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
from
|
| 23 |
-
from
|
| 24 |
-
from
|
| 25 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Importing chardet (make sure to add chardet to your requirements.txt)
|
| 28 |
import chardet
|
|
@@ -41,7 +57,22 @@ import uvicorn
|
|
| 41 |
from typing import Optional
|
| 42 |
import io
|
| 43 |
import soundfile as sf
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Enhanced Whisper model for speech-to-text with better configuration
|
| 47 |
try:
|
|
@@ -53,24 +84,44 @@ try:
|
|
| 53 |
stride_length_s=5,
|
| 54 |
batch_size=8
|
| 55 |
)
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
print(f"Warning: Could not load enhanced Whisper model: {e}")
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
# Set API Key (Ensure it's stored securely in an environment variable)
|
| 62 |
groq.api_key = os.getenv("GROQ_API_KEY")
|
| 63 |
|
| 64 |
-
# Initialize Chat Model
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
# Initialize Embeddings and chromaDB
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Short-term memory for the LLM
|
| 76 |
chat_memory = []
|
|
@@ -140,6 +191,9 @@ def clean_response(response):
|
|
| 140 |
|
| 141 |
# Function to generate quiz based on content
|
| 142 |
def generate_quiz(content):
|
|
|
|
|
|
|
|
|
|
| 143 |
prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
|
| 144 |
response = chat_model([HumanMessage(content=prompt)])
|
| 145 |
cleaned_response = clean_response(response.content)
|
|
@@ -147,6 +201,9 @@ def generate_quiz(content):
|
|
| 147 |
|
| 148 |
# Function to retrieve relevant documents from vectorstore based on user query
|
| 149 |
def retrieve_documents(query):
|
|
|
|
|
|
|
|
|
|
| 150 |
results = vectorstore.similarity_search(query, k=3)
|
| 151 |
return [doc.page_content for doc in results]
|
| 152 |
|
|
@@ -173,6 +230,12 @@ def convert_to_tuple_format(chat_history):
|
|
| 173 |
# Function to handle chatbot interactions with short-term memory
|
| 174 |
def chat_with_groq(user_input, chat_history):
|
| 175 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
# Convert message format to tuple format for processing
|
| 177 |
tuple_history = convert_to_tuple_format(chat_history)
|
| 178 |
|
|
@@ -284,6 +347,10 @@ def process_document(file):
|
|
| 284 |
encoding = detect_encoding(file.name)
|
| 285 |
with open(file.name, "r", encoding=encoding, errors="replace") as f:
|
| 286 |
content = f.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
| 288 |
documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
|
| 289 |
vectorstore.add_documents(documents)
|
|
@@ -312,16 +379,16 @@ def preprocess_audio(audio_data, sample_rate):
|
|
| 312 |
audio_data = audio_data / max_val
|
| 313 |
|
| 314 |
# Resample to 16kHz if needed (Whisper works best with 16kHz)
|
| 315 |
-
if sample_rate != AUDIO_SAMPLE_RATE:
|
| 316 |
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
|
| 317 |
sample_rate = AUDIO_SAMPLE_RATE
|
| 318 |
|
| 319 |
-
# Apply noise reduction (simple high-pass filter)
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
|
| 326 |
return audio_data, sample_rate
|
| 327 |
|
|
@@ -338,6 +405,9 @@ def transcribe_audio(audio):
|
|
| 338 |
if audio is None:
|
| 339 |
return "No audio input detected."
|
| 340 |
|
|
|
|
|
|
|
|
|
|
| 341 |
sample_rate, audio_data = audio
|
| 342 |
|
| 343 |
# Preprocess audio
|
|
@@ -427,6 +497,12 @@ async def api_chat(message: str = Form(...)):
|
|
| 427 |
API endpoint for chat interactions
|
| 428 |
"""
|
| 429 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
# Simple chat response without memory for API
|
| 431 |
prompt = f"You are a helpful AI tutor. Answer the following question accurately and concisely: {message}"
|
| 432 |
response = chat_model([HumanMessage(content=prompt)])
|
|
@@ -490,7 +566,13 @@ async def api_process_document(file: UploadFile = File(...)):
|
|
| 490 |
@app.get("/api/health")
|
| 491 |
async def health_check():
|
| 492 |
"""Health check endpoint"""
|
| 493 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
# Clear chat history function
|
| 496 |
def clear_chat_history():
|
|
@@ -598,15 +680,21 @@ def tutor_ai_chatbot():
|
|
| 598 |
gr.Video("We_not_me_video.mp4", label="Introduction Video")
|
| 599 |
|
| 600 |
# Launch the application
|
| 601 |
-
gradio_app.launch(share=False)
|
| 602 |
|
| 603 |
-
# Run
|
| 604 |
if __name__ == "__main__":
|
| 605 |
-
import
|
| 606 |
-
|
| 607 |
-
# Start Gradio in a separate thread
|
| 608 |
-
gradio_thread = threading.Thread(target=tutor_ai_chatbot, daemon=True)
|
| 609 |
-
gradio_thread.start()
|
| 610 |
|
| 611 |
-
|
| 612 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import groq
|
| 17 |
import uuid # For generating unique filenames
|
| 18 |
|
| 19 |
+
# Updated imports for LangChain compatibility
|
| 20 |
+
try:
|
| 21 |
+
# For newer versions of LangChain
|
| 22 |
+
from langchain_groq import ChatGroq
|
| 23 |
+
from langchain_core.messages import HumanMessage
|
| 24 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 25 |
+
from langchain_community.vectorstores import Chroma
|
| 26 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 27 |
+
from langchain_core.documents import Document
|
| 28 |
+
except ImportError:
|
| 29 |
+
# Fallback for older versions
|
| 30 |
+
try:
|
| 31 |
+
from langchain_groq import ChatGroq
|
| 32 |
+
from langchain.schema import HumanMessage
|
| 33 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 34 |
+
from langchain_community.vectorstores import Chroma
|
| 35 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 36 |
+
from langchain.docstore.document import Document
|
| 37 |
+
except ImportError as e:
|
| 38 |
+
print(f"Import error: {e}")
|
| 39 |
+
# No LangChain fallback is imported here; just warn and continue (the model/vectorstore initialisation below is wrapped in try/except)
|
| 40 |
+
import warnings
|
| 41 |
+
warnings.warn("Some LangChain components not available")
|
| 42 |
|
| 43 |
# Importing chardet (make sure to add chardet to your requirements.txt)
|
| 44 |
import chardet
|
|
|
|
| 57 |
from typing import Optional
|
| 58 |
import io
|
| 59 |
import soundfile as sf
|
| 60 |
+
|
| 61 |
+
# Try to import librosa for audio processing, but make it optional
|
| 62 |
+
try:
|
| 63 |
+
import librosa
|
| 64 |
+
LIBROSA_AVAILABLE = True
|
| 65 |
+
except ImportError:
|
| 66 |
+
LIBROSA_AVAILABLE = False
|
| 67 |
+
print("Warning: librosa not available. Audio preprocessing will be limited.")
|
| 68 |
+
|
| 69 |
+
# Try to import scipy for audio filtering, but make it optional
|
| 70 |
+
try:
|
| 71 |
+
import scipy.signal as sp
|
| 72 |
+
SCIPY_AVAILABLE = True
|
| 73 |
+
except ImportError:
|
| 74 |
+
SCIPY_AVAILABLE = False
|
| 75 |
+
print("Warning: scipy not available. Audio filtering will be limited.")
|
| 76 |
|
| 77 |
# Enhanced Whisper model for speech-to-text with better configuration
|
| 78 |
try:
|
|
|
|
| 84 |
stride_length_s=5,
|
| 85 |
batch_size=8
|
| 86 |
)
|
| 87 |
+
print("Loaded Whisper-small model successfully")
|
| 88 |
except Exception as e:
|
| 89 |
print(f"Warning: Could not load enhanced Whisper model: {e}")
|
| 90 |
+
try:
|
| 91 |
+
# Fallback to basic model
|
| 92 |
+
transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
|
| 93 |
+
print("Loaded Whisper-base model as fallback")
|
| 94 |
+
except Exception as e2:
|
| 95 |
+
print(f"Error loading any Whisper model: {e2}")
|
| 96 |
+
transcriber = None
|
| 97 |
|
| 98 |
# Set API Key (Ensure it's stored securely in an environment variable)
|
| 99 |
groq.api_key = os.getenv("GROQ_API_KEY")
|
| 100 |
|
| 101 |
+
# Initialize Chat Model with error handling
|
| 102 |
+
try:
|
| 103 |
+
if groq.api_key:
|
| 104 |
+
chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq.api_key)
|
| 105 |
+
CHAT_MODEL_AVAILABLE = True
|
| 106 |
+
else:
|
| 107 |
+
print("GROQ_API_KEY not found in environment variables")
|
| 108 |
+
CHAT_MODEL_AVAILABLE = False
|
| 109 |
+
except Exception as e:
|
| 110 |
+
print(f"Error initializing chat model: {e}")
|
| 111 |
+
CHAT_MODEL_AVAILABLE = False
|
| 112 |
|
| 113 |
+
# Initialize Embeddings and chromaDB with error handling
|
| 114 |
+
try:
|
| 115 |
+
os.makedirs("chroma_db", exist_ok=True)
|
| 116 |
+
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 117 |
+
vectorstore = Chroma(
|
| 118 |
+
embedding_function=embedding_model,
|
| 119 |
+
persist_directory="chroma_db"
|
| 120 |
+
)
|
| 121 |
+
VECTORSTORE_AVAILABLE = True
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print(f"Error initializing vectorstore: {e}")
|
| 124 |
+
VECTORSTORE_AVAILABLE = False
|
| 125 |
|
| 126 |
# Short-term memory for the LLM
|
| 127 |
chat_memory = []
|
|
|
|
| 191 |
|
| 192 |
# Function to generate quiz based on content
|
| 193 |
def generate_quiz(content):
|
| 194 |
+
if not CHAT_MODEL_AVAILABLE:
|
| 195 |
+
return "Chat model not available. Please check GROQ_API_KEY configuration."
|
| 196 |
+
|
| 197 |
prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
|
| 198 |
response = chat_model([HumanMessage(content=prompt)])
|
| 199 |
cleaned_response = clean_response(response.content)
|
|
|
|
| 201 |
|
| 202 |
# Function to retrieve relevant documents from vectorstore based on user query
|
| 203 |
def retrieve_documents(query):
|
| 204 |
+
if not VECTORSTORE_AVAILABLE:
|
| 205 |
+
return ["Vector store not available."]
|
| 206 |
+
|
| 207 |
results = vectorstore.similarity_search(query, k=3)
|
| 208 |
return [doc.page_content for doc in results]
|
| 209 |
|
|
|
|
| 230 |
# Function to handle chatbot interactions with short-term memory
|
| 231 |
def chat_with_groq(user_input, chat_history):
|
| 232 |
try:
|
| 233 |
+
if not CHAT_MODEL_AVAILABLE:
|
| 234 |
+
error_msg = "Chat model not available. Please check configuration."
|
| 235 |
+
chat_history.append({"role": "user", "content": user_input})
|
| 236 |
+
chat_history.append({"role": "assistant", "content": error_msg})
|
| 237 |
+
return chat_history, "", None
|
| 238 |
+
|
| 239 |
# Convert message format to tuple format for processing
|
| 240 |
tuple_history = convert_to_tuple_format(chat_history)
|
| 241 |
|
|
|
|
| 347 |
encoding = detect_encoding(file.name)
|
| 348 |
with open(file.name, "r", encoding=encoding, errors="replace") as f:
|
| 349 |
content = f.read()
|
| 350 |
+
|
| 351 |
+
if not VECTORSTORE_AVAILABLE:
|
| 352 |
+
return f"Document processed but vector store not available. Content preview: {content[:500]}..."
|
| 353 |
+
|
| 354 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
| 355 |
documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
|
| 356 |
vectorstore.add_documents(documents)
|
|
|
|
| 379 |
audio_data = audio_data / max_val
|
| 380 |
|
| 381 |
# Resample to 16kHz if needed (Whisper works best with 16kHz)
|
| 382 |
+
if LIBROSA_AVAILABLE and sample_rate != AUDIO_SAMPLE_RATE:
|
| 383 |
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
|
| 384 |
sample_rate = AUDIO_SAMPLE_RATE
|
| 385 |
|
| 386 |
+
# Apply noise reduction (simple high-pass filter) if scipy is available
|
| 387 |
+
if SCIPY_AVAILABLE:
|
| 388 |
+
nyquist = sample_rate / 2
|
| 389 |
+
cutoff = 80 # High-pass filter cutoff frequency in Hz
|
| 390 |
+
b, a = sp.butter(2, cutoff/nyquist, btype='high')
|
| 391 |
+
audio_data = sp.filtfilt(b, a, audio_data)
|
| 392 |
|
| 393 |
return audio_data, sample_rate
|
| 394 |
|
|
|
|
| 405 |
if audio is None:
|
| 406 |
return "No audio input detected."
|
| 407 |
|
| 408 |
+
if transcriber is None:
|
| 409 |
+
return "Speech-to-text service not available. Please check model configuration."
|
| 410 |
+
|
| 411 |
sample_rate, audio_data = audio
|
| 412 |
|
| 413 |
# Preprocess audio
|
|
|
|
| 497 |
API endpoint for chat interactions
|
| 498 |
"""
|
| 499 |
try:
|
| 500 |
+
if not CHAT_MODEL_AVAILABLE:
|
| 501 |
+
return JSONResponse({
|
| 502 |
+
"success": False,
|
| 503 |
+
"error": "Chat model not available"
|
| 504 |
+
}, status_code=503)
|
| 505 |
+
|
| 506 |
# Simple chat response without memory for API
|
| 507 |
prompt = f"You are a helpful AI tutor. Answer the following question accurately and concisely: {message}"
|
| 508 |
response = chat_model([HumanMessage(content=prompt)])
|
|
|
|
| 566 |
@app.get("/api/health")
|
| 567 |
async def health_check():
|
| 568 |
"""Health check endpoint"""
|
| 569 |
+
return {
|
| 570 |
+
"status": "healthy",
|
| 571 |
+
"timestamp": time.time(),
|
| 572 |
+
"chat_model_available": CHAT_MODEL_AVAILABLE,
|
| 573 |
+
"vectorstore_available": VECTORSTORE_AVAILABLE,
|
| 574 |
+
"stt_available": transcriber is not None
|
| 575 |
+
}
|
| 576 |
|
| 577 |
# Clear chat history function
|
| 578 |
def clear_chat_history():
|
|
|
|
| 680 |
gr.Video("We_not_me_video.mp4", label="Introduction Video")
|
| 681 |
|
| 682 |
# Launch the application
|
| 683 |
+
gradio_app.launch(share=False, server_name="0.0.0.0", server_port=7860)
|
| 684 |
|
| 685 |
+
# Run the application based on command line arguments
|
| 686 |
if __name__ == "__main__":
|
| 687 |
+
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
|
| 689 |
+
if len(sys.argv) > 1 and sys.argv[1] == "api":
|
| 690 |
+
# Run only FastAPI
|
| 691 |
+
print("Starting FastAPI server on port 8000...")
|
| 692 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 693 |
+
elif len(sys.argv) > 1 and sys.argv[1] == "gradio":
|
| 694 |
+
# Run only Gradio
|
| 695 |
+
print("Starting Gradio interface on port 7860...")
|
| 696 |
+
tutor_ai_chatbot()
|
| 697 |
+
else:
|
| 698 |
+
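# Default mode: run Gradio in the main thread (NOTE(review): no FastAPI/uvicorn start-up is visible in this branch — confirm whether the API server is still meant to start here)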
|
| 699 |
+
print("Starting Gradio interface (main thread) and FastAPI server (background)...")
|
| 700 |
+
tutor_ai_chatbot()
|