import os
import re
import uuid
import time
import tempfile
import numpy as np
import gradio as gr
import chardet
import fitz  # PyMuPDF
import docx
import gtts
from pptx import Presentation
from typing import TypedDict, List
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.graph import StateGraph, END
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
# --- 1. INITIALIZATION & CORE TOOLS ---
groq_api_key = os.getenv("GROQ_API_KEY")
chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq_api_key)
web_search_tool = DuckDuckGoSearchRun()
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma(embedding_function=embedding_model, persist_directory="chroma_db")
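# Assumption: GROQ_API_KEY is provided via the environment (e.g. a Spaces secret
# or `export GROQ_API_KEY="gsk_..."`). If it is unset, ChatGroq requests will
# fail at call time rather than at this point in the script.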
# --- 2. HELPER FUNCTIONS ---
def clean_response(response):
    """Remove <think>...</think> blocks and common markdown artifacts."""
    # Remove think tags and their content (greedily, case-insensitive)
    cleaned = re.sub(r"<think>.*?(?:</think>|$)", "", response, flags=re.DOTALL | re.IGNORECASE)
    # Remove stray closing tags and markdown symbols
    cleaned = re.sub(r"</?think>|\*\*|\*|\[|\]|#", "", cleaned)
    return cleaned.strip()
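# Example (sketch): clean_response("<think>plan</think>**Answer**: [42]")
# returns "Answer: 42" -- the think block and the markdown symbols are stripped.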
def retrieve_documents(query):
    results = vectorstore.similarity_search(query, k=3)
    return [doc.page_content for doc in results]
def speech_playback(text):
    try:
        unique_id = str(uuid.uuid4())
        # Write to the system temp dir instead of the Colab-only /content path
        audio_file = os.path.join(tempfile.gettempdir(), f"output_audio_{unique_id}.mp3")
        tts = gtts.gTTS(text[:500], lang='en')
        tts.save(audio_file)
        return audio_file
    except Exception as e:
        print(f"TTS error: {e}")
        return None
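# Example (sketch): speech_playback("Hello, student!") returns the path of an
# mp3 in the system temp dir, or None if gTTS fails (e.g. no network access).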
# --- 3. DOCUMENT INGESTION FUNCTION ---
def extract_and_store_document(file_path: str):
    text = ""
    file_ext = os.path.splitext(file_path)[1].lower()
    try:
        if file_ext == ".pdf":
            doc = fitz.open(file_path)
            for page in doc:
                text += page.get_text()
            doc.close()
        elif file_ext == ".docx":
            doc = docx.Document(file_path)
            text = "\n".join([para.text for para in doc.paragraphs])
        elif file_ext == ".pptx":
            prs = Presentation(file_path)
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        text += shape.text + "\n"
        else:
            with open(file_path, 'rb') as f:
                raw_data = f.read()
            encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'
            text = raw_data.decode(encoding, errors='ignore')
        if not text.strip():
            return False
        splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = splitter.split_text(text)
        documents = [Document(page_content=chunk, metadata={"source": os.path.basename(file_path)}) for chunk in chunks]
        # Chroma persists automatically in 0.4.x+; no explicit persist() call needed
        vectorstore.add_documents(documents)
        return True
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return False
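# Example usage (sketch; the file path is hypothetical):
#   ok = extract_and_store_document("notes/lecture1.pdf")
#   print("stored" if ok else "nothing extracted")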
# --- 4. REFRAG MULTI-AGENT LOGIC (LangGraph) ---
class AgentState(TypedDict):
    messages: List[BaseMessage]
    context: str
    decision: str
    source: str

def sensing_node(state: AgentState):
    user_query = state["messages"][-1].content
    relevant_docs = retrieve_documents(user_query)
    context = "\n".join(relevant_docs) if relevant_docs else ""
    prompt = f"Docs: {context}\nQuery: {user_query}\nIf docs answer this, reply 'RAG'. Else reply 'WEB'."
    decision = chat_model.invoke([HumanMessage(content=prompt)]).content.strip().upper()
    return {"context": context, "decision": "RAG" if "RAG" in decision else "WEB"}
# Web search with fallback strategies and simple rate limiting.
# Module-level state, defined before the function that uses it:
last_web_search_time = 0
WEB_SEARCH_COOLDOWN = 10  # seconds between web searches

def safe_web_search_with_fallback(query: str):
    """Web search with multiple fallback strategies and a cooldown between calls."""
    global last_web_search_time
    strategies = [
        # Strategy 1: direct search
        lambda: web_search_tool.run(query),
        # Strategy 2: drop everything after the first question mark
        lambda: web_search_tool.run(query.split("?")[0] if "?" in query else query),
        # Strategy 3: first ten words only
        lambda: web_search_tool.run(' '.join(query.split()[:10])),
    ]
    for i, strategy in enumerate(strategies):
        try:
            # Enforce the cooldown before each attempt
            current_time = time.time()
            if current_time - last_web_search_time < WEB_SEARCH_COOLDOWN:
                time.sleep(WEB_SEARCH_COOLDOWN - (current_time - last_web_search_time))
            result = strategy()
            last_web_search_time = time.time()
            if result and len(result) > 50:  # Treat very short results as misses
                return result[:2000]  # Truncate to avoid context overflow
        except Exception as e:
            if i == len(strategies) - 1:  # All strategies failed
                return f"Web search unavailable. Error: {str(e)[:100]}"
            continue
    return "Web search temporarily unavailable."
def expansion_node(state: AgentState):
    # Rate limiting and error handling live in safe_web_search_with_fallback
    if state["decision"] == "WEB":
        user_query = state["messages"][-1].content
        web_data = safe_web_search_with_fallback(user_query)
        return {
            "context": f"WEB INFO: {web_data}\nLOCAL: {state['context']}",
            "source": "Web + Local Documents",
        }
    return {"source": "Local Documents Only"}
def generation_node(state: AgentState):
    system_msg = f"You are a Tutor AI. Use this context: {state['context']}"
    response = chat_model.invoke([SystemMessage(content=system_msg)] + state["messages"])
    cleaned = clean_response(response.content)
    return {"messages": [AIMessage(content=f"{cleaned}\n\n*(Verified via: {state['source']})*")]}

workflow = StateGraph(AgentState)
workflow.add_node("sense", sensing_node)
workflow.add_node("expand", expansion_node)
workflow.add_node("generate", generation_node)
workflow.set_entry_point("sense")
workflow.add_edge("sense", "expand")
workflow.add_edge("expand", "generate")
workflow.add_edge("generate", END)
app_agent = workflow.compile()
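# Quick sanity check (sketch; uncomment to run once outside the UI):
# state = {"messages": [HumanMessage(content="What is photosynthesis?")],
#          "context": "", "decision": "", "source": ""}
# print(app_agent.invoke(state)["messages"][-1].content)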
# --- 5. GRADIO APP WITH MANUAL AUDIO ---
# Store last assistant response globally (simple approach for demo)
last_assistant_response = ""

def chat_handler(user_input, chat_history):
    global last_assistant_response
    if not user_input:
        return chat_history, "", None
    inputs = {"messages": [HumanMessage(content=user_input)], "context": "", "decision": "", "source": ""}
    result = app_agent.invoke(inputs)
    final_msg = result["messages"][-1].content
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": final_msg})
    # Save clean text for later TTS (without the source note)
    last_assistant_response = final_msg.split("*(Verified")[0].strip()
    # Return chat history, clear the textbox, and clear audio (no autoplay)
    return chat_history, "", None

def generate_audio():
    global last_assistant_response
    if not last_assistant_response:
        return None
    return speech_playback(last_assistant_response)
def upload_file(file):
    if file is None:
        return "❌ No file uploaded."
    try:
        success = extract_and_store_document(file.name)
        if success:
            return f"✅ **{os.path.basename(file.name)}** successfully parsed and added to knowledge base!"
        else:
            return f"⚠️ Failed to extract text from **{os.path.basename(file.name)}**."
    except Exception as e:
        return f"❌ Error: {str(e)}"
with gr.Blocks() as demo:
    gr.Markdown("# 📚 REFRAG Multi-Agent Tutor")
    with gr.Tab("AI Chatbot"):
        # type="messages" is required: chat_handler appends role/content dicts
        chatbot = gr.Chatbot(type="messages", height=400)
        with gr.Row():
            msg = gr.Textbox(placeholder="Ask your tutor...", scale=4)
            submit = gr.Button("Send", variant="primary")
        # Manual audio control
        with gr.Row():
            play_audio_btn = gr.Button("🔊 Play Audio Response", variant="secondary")
            audio_out = gr.Audio(label="Audio Response", autoplay=False)
        # Chat submission
        submit.click(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
        msg.submit(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
        # Manual audio generation
        play_audio_btn.click(generate_audio, None, audio_out)
    with gr.Tab("Upload Notes"):
        file_input = gr.File(label="Upload PDF / DOCX / PPTX / TXT", file_types=[".pdf", ".docx", ".pptx", ".txt"])
        upload_status = gr.Markdown()
        file_input.change(upload_file, file_input, upload_status)

demo.launch(share=True, debug=True)