# Test_TutorAI_We / app.py
# Source: Hugging Face Space by Lesterchia1, commit 36698b9 (verified)
import os
import re
import uuid
import time # Add this
import tempfile
import numpy as np
import gradio as gr
import chardet
import fitz # PyMuPDF
import docx
import gtts
from pptx import Presentation
from typing import TypedDict, List
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.graph import StateGraph, END
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
# --- 1. INITIALIZATION & CORE TOOLS ---
# Groq-hosted Llama 3.3 70B is the single LLM shared by every agent node.
groq_api_key = os.getenv("GROQ_API_KEY")
chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq_api_key)
# DuckDuckGo search backs the "WEB" route chosen by the sensing node.
web_search_tool = DuckDuckGoSearchRun()
# MiniLM sentence embeddings feed a Chroma vector store persisted on disk.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma(embedding_function=embedding_model, persist_directory="chroma_db")
# --- 2. HELPER FUNCTIONS ---
def clean_response(response):
    """Strip model reasoning and markdown noise from an LLM reply.

    Removes any ``<think>...</think>`` blocks (including an unterminated
    trailing one), stray think tags, and common markdown symbols
    (``*``, ``**``, ``[``, ``]``, ``#``), then trims whitespace.

    Args:
        response: Raw text returned by the chat model.

    Returns:
        The cleaned plain-text response.
    """
    # "(?:</think>|$)" also catches a think block left unclosed at EOF.
    cleaned = re.sub(r"<think>.*?(?:</think>|$)", "", response,
                     flags=re.DOTALL | re.IGNORECASE)
    # Drop orphaned think tags and markdown formatting characters.
    cleaned = re.sub(r"</?think>|\*\*|\*|\[|\]|#", "", cleaned)
    return cleaned.strip()
def retrieve_documents(query):
    """Return the page contents of the top-3 chunks most similar to *query*."""
    hits = vectorstore.similarity_search(query, k=3)
    return [hit.page_content for hit in hits]
def speech_playback(text):
    """Synthesize *text* (first 500 chars) to an MP3 via Google TTS.

    Args:
        text: The text to speak.

    Returns:
        Path to the generated audio file, or None on failure.
    """
    try:
        # Write into the system temp dir instead of the Colab-only
        # "/content" path, which does not exist on other hosts.
        audio_file = os.path.join(
            tempfile.gettempdir(), f"output_audio_{uuid.uuid4()}.mp3"
        )
        # Cap at 500 chars to keep the gTTS request size and latency sane.
        tts = gtts.gTTS(text[:500], lang='en')
        tts.save(audio_file)
        return audio_file
    except Exception as e:
        # Audio is optional: log and degrade gracefully instead of crashing.
        print(f"TTS error: {e}")
        return None
# --- 3. DOCUMENT INGESTION FUNCTION ---
def extract_and_store_document(file_path: str):
    """Extract text from an uploaded document and index it in the vector store.

    Supports PDF, DOCX and PPTX; any other extension is read as plain text
    with charset detection. Extracted text is split into overlapping chunks
    and added to the global Chroma ``vectorstore``.

    Args:
        file_path: Path to the uploaded file.

    Returns:
        True if text was extracted and indexed, False otherwise.
    """
    text = ""
    file_ext = os.path.splitext(file_path)[1].lower()
    try:
        if file_ext == ".pdf":
            # Context manager guarantees the PDF handle is closed even if a
            # page fails to parse (the original leaked it on exception).
            with fitz.open(file_path) as pdf:
                for page in pdf:
                    text += page.get_text()
        elif file_ext == ".docx":
            doc = docx.Document(file_path)
            text = "\n".join(para.text for para in doc.paragraphs)
        elif file_ext == ".pptx":
            prs = Presentation(file_path)
            for slide in prs.slides:
                for shape in slide.shapes:
                    # Not every shape carries text (e.g. pictures).
                    if hasattr(shape, "text"):
                        text += shape.text + "\n"
        else:
            # Unknown extension: treat as text and guess the encoding.
            with open(file_path, 'rb') as f:
                raw_data = f.read()
            encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'
            text = raw_data.decode(encoding, errors='ignore')
        if not text.strip():
            return False
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_text(text)
        documents = [
            Document(page_content=chunk,
                     metadata={"source": os.path.basename(file_path)})
            for chunk in chunks
        ]
        # Chroma 0.4.x+ persists automatically; no explicit persist() needed.
        vectorstore.add_documents(documents)
        return True
    except Exception as e:
        # Swallow parse failures so one bad upload can't crash the app.
        print(f"Error processing {file_path}: {e}")
        return False
# --- 4. REFRAG MULTI-AGENT LOGIC (LangGraph) ---
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph pipeline nodes."""
    messages: List[BaseMessage]  # conversation so far; last entry is the user query
    context: str  # retrieved document and/or web text used for grounding
    decision: str  # routing flag set by sensing_node: "RAG" or "WEB"
    source: str  # provenance label appended to the final answer
def sensing_node(state: AgentState):
    """Route the query: local documents suffice (RAG) or web search is needed (WEB)."""
    query = state["messages"][-1].content
    docs = retrieve_documents(query)
    context = "\n".join(docs) if docs else ""
    router_prompt = (
        f"Docs: {context}\nQuery: {query}\n"
        "If docs answer this, reply 'RAG'. Else reply 'WEB'."
    )
    verdict = chat_model.invoke([HumanMessage(content=router_prompt)]).content.strip().upper()
    # Default to WEB unless the model clearly voted for RAG.
    routed = "RAG" if "RAG" in verdict else "WEB"
    return {"context": context, "decision": routed}
def safe_web_search_with_fallback(query: str):
    """Run a DuckDuckGo search with rate limiting and fallback strategies.

    Tries progressively simpler forms of *query* until one yields a
    substantial result. The module-level ``last_web_search_time`` enforces
    a cooldown between searches to avoid provider rate limits.

    Args:
        query: The user's question.

    Returns:
        Up to 2000 characters of search results, or an explanatory
        message if every strategy fails.
    """
    global last_web_search_time
    strategies = [
        # 1) The query verbatim.
        lambda: web_search_tool.run(query),
        # 2) Everything before the first question mark.
        lambda: web_search_tool.run(query.split("?")[0] if "?" in query else query),
        # 3) Just the first ten words (keywords only).
        lambda: web_search_tool.run(' '.join(query.split()[:10])),
    ]
    for i, strategy in enumerate(strategies):
        try:
            # Honor the configured cooldown between consecutive searches
            # (the original hardcoded 5 s, out of sync with WEB_SEARCH_COOLDOWN).
            elapsed = time.time() - last_web_search_time
            if elapsed < WEB_SEARCH_COOLDOWN:
                time.sleep(WEB_SEARCH_COOLDOWN - elapsed)
            result = strategy()
            last_web_search_time = time.time()
            if result and len(result) > 50:  # Ignore trivially short results.
                return result[:2000]  # Truncate to keep the LLM context small.
        except Exception as e:
            if i == len(strategies) - 1:  # All strategies exhausted.
                return f"Web search unavailable. Error: {str(e)[:100]}"
            continue
    return "Web search temporarily unavailable."
# Module-level rate-limiting state shared by the web-search helpers.
last_web_search_time = 0
WEB_SEARCH_COOLDOWN = 10 # 10 seconds between web searches
def expansion_node(state: AgentState):
    """Augment the context with web results when the router chose 'WEB'.

    The rate-limiting branch that followed the returns in the original was
    unreachable dead code and has been removed; the cooldown is already
    enforced inside safe_web_search_with_fallback().

    Returns:
        A partial state update with the (possibly web-augmented) context
        and a provenance label for the final answer.
    """
    if state["decision"] == "WEB":
        user_query = state["messages"][-1].content
        web_data = safe_web_search_with_fallback(user_query)
        return {
            "context": f"WEB INFO: {web_data}\nLOCAL: {state['context']}",
            "source": "Web + Local Documents"
        }
    # RAG route: the local context gathered by sensing_node is enough.
    return {"source": "Local Documents Only"}
def generation_node(state: AgentState):
    """Produce the tutor's final answer, grounded in the gathered context."""
    system_msg = f"You are a Tutor AI. Use this context: {state['context']}"
    conversation = [SystemMessage(content=system_msg)] + state["messages"]
    raw = chat_model.invoke(conversation)
    answer = clean_response(raw.content)
    # Append the provenance footer so the user can see where facts came from.
    final = f"{answer}\n\n*(Verified via: {state['source']})*"
    return {"messages": [AIMessage(content=final)]}
# Wire the three agents into a linear LangGraph pipeline:
# sense (route) -> expand (optional web search) -> generate (answer).
workflow = StateGraph(AgentState)
workflow.add_node("sense", sensing_node)
workflow.add_node("expand", expansion_node)
workflow.add_node("generate", generation_node)
workflow.set_entry_point("sense")
workflow.add_edge("sense", "expand")
workflow.add_edge("expand", "generate")
workflow.add_edge("generate", END)
app_agent = workflow.compile()
# --- 5. GRADIO APP WITH MANUAL AUDIO ---
# Last assistant reply kept module-global so the "Play Audio" button can
# synthesize it on demand (simple single-user demo approach).
last_assistant_response = ""
def chat_handler(user_input, chat_history):
    """Run one chat turn through the agent graph and update the history.

    Returns the updated history, an empty string to clear the textbox,
    and None so no audio auto-plays.
    """
    global last_assistant_response
    if not user_input:
        return chat_history, "", None
    inputs = {"messages": [HumanMessage(content=user_input)], "context": "", "decision": "", "source": ""}
    result = app_agent.invoke(inputs)
    reply = result["messages"][-1].content
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": reply})
    # Keep only the answer text (drop the provenance footer) for later TTS.
    last_assistant_response = reply.split("*(Verified")[0].strip()
    return chat_history, "", None
def generate_audio():
    """Synthesize the most recent assistant reply; None if there is none yet."""
    if last_assistant_response:
        return speech_playback(last_assistant_response)
    return None
def upload_file(file):
    """Ingest an uploaded file into the knowledge base.

    Args:
        file: A Gradio file object (exposes ``.name`` as a temp path),
              or None when nothing was uploaded.

    Returns:
        A markdown status string. Status emojis were mojibake in the
        original ("โŒ" etc. are UTF-8 emoji mis-decoded) and are restored.
    """
    if file is None:
        return "❌ No file uploaded."
    try:
        path = file.name
        base = os.path.basename(path)
        if extract_and_store_document(path):
            return f"✅ **{base}** successfully parsed and added to knowledge base!"
        return f"⚠️ Failed to extract text from **{base}**."
    except Exception as e:
        return f"❌ Error: {str(e)}"
# UI layout. Header/button emojis were mojibake in the original paste
# ("๐ŸŽ“", "๐Ÿ”Š" are UTF-8 emoji mis-decoded) and are restored here.
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 REFRAG Multi-Agent Tutor")

    with gr.Tab("AI Chatbot"):
        chatbot = gr.Chatbot(value=[], height=400)
        with gr.Row():
            msg = gr.Textbox(placeholder="Ask your tutor...", scale=4)
            submit = gr.Button("Send", variant="primary")
        # Audio is generated only on demand -- no autoplay surprises.
        with gr.Row():
            play_audio_btn = gr.Button("🔊 Play Audio Response", variant="secondary")
            audio_out = gr.Audio(label="Audio Response", autoplay=False)
        # Both the Send button and pressing Enter submit the message.
        submit.click(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
        msg.submit(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
        play_audio_btn.click(generate_audio, None, audio_out)

    with gr.Tab("Upload Notes"):
        file_input = gr.File(label="Upload PDF / DOCX / PPTX / TXT",
                             file_types=[".pdf", ".docx", ".pptx", ".txt"])
        upload_status = gr.Markdown()
        # Ingest as soon as a file is chosen; show the status below.
        file_input.change(upload_file, file_input, upload_status)

demo.launch(share=True, debug=True)