| """ | |
| Enhanced Multi-LLM Agent System with Supabase FAISS Integration | |
| Complete system for document insertion, retrieval, and question answering | |
| """ | |
| import os | |
| import time | |
| import random | |
| import operator | |
| from typing import List, Dict, Any, TypedDict, Annotated, Optional | |
| from dotenv import load_dotenv | |
| from langchain_core.tools import tool | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_groq import ChatGroq | |
| # Supabase and FAISS imports | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from supabase import create_client, Client | |
| import pandas as pd | |
| import json | |
| import pickle | |
| load_dotenv() | |

# Enhanced system prompt for question-answering
ENHANCED_SYSTEM_PROMPT = (
    "You are a helpful assistant tasked with answering questions using a set of tools. "
    "You must provide accurate, comprehensive answers based on available information. "
    "When answering questions, follow these guidelines:\n"
    "1. Use available tools to gather information when needed\n"
    "2. Provide precise, factual answers\n"
    "3. For numbers: don't use commas or units unless specified\n"
    "4. For strings: don't use articles or abbreviations, write digits in plain text\n"
    "5. For lists: apply the above rules based on element type\n"
    "6. Always end with 'FINAL ANSWER: [YOUR ANSWER]'\n"
    "7. Be concise but thorough in your reasoning\n"
    "8. If you cannot find the answer, state that clearly"
)

# ---- Tool Definitions ----
# The @tool decorators were missing; they are required because these functions
# are later collected in self.tools and called via .invoke({...}).
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder when dividing the first integer by the second."""
    return a % b
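
# Minimal usage sketch (defined but not executed at import time): @tool-wrapped
# functions are invoked with a dict of arguments rather than positionally.
def _demo_math_tools() -> None:
    """Illustrative only; assumes the @tool decorators above."""
    print(multiply.invoke({"a": 25, "b": 17}))  # 425
    print(divide.invoke({"a": 10, "b": 4}))     # 2.5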

@tool
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults."""
    try:
        time.sleep(random.uniform(0.7, 1.5))  # brief jitter to avoid rate limits
        search_tool = TavilySearchResults(max_results=3)
        docs = search_tool.invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"

@tool
def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return content snippets."""
    try:
        time.sleep(random.uniform(0.3, 1))  # brief jitter to avoid rate limits
        docs = WikipediaLoader(query=query, load_max_docs=2).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:1000]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"

# ---- Supabase FAISS Vector Database Integration ----
class SupabaseFAISSVectorDB:
    """Enhanced vector database combining FAISS with Supabase for persistent storage."""

    def __init__(self):
        # Initialize Supabase client
        self.supabase_url = os.getenv("SUPABASE_URL")
        self.supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
        if self.supabase_url and self.supabase_key:
            self.supabase: Client = create_client(self.supabase_url, self.supabase_key)
        else:
            self.supabase = None
            print("Supabase credentials not found, running without vector database")

        # Initialize embedding model
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()

        # Initialize FAISS index (exact L2 search)
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.document_store = []  # Local cache for documents

    def insert_question_data(self, data: Dict[str, Any]) -> bool:
        """Insert question data into both Supabase and FAISS."""
        try:
            question_text = data.get("Question", "")
            embedding = self.embedding_model.encode([question_text])[0]

            # Insert into Supabase if available
            if self.supabase:
                question_data = {
                    "task_id": data.get("task_id"),
                    "question": question_text,
                    "final_answer": data.get("Final answer"),
                    "level": data.get("Level"),
                    "file_name": data.get("file_name", ""),
                    "embedding": embedding.tolist(),
                }
                self.supabase.table("questions").insert(question_data).execute()

            # Add to local FAISS index
            self.index.add(embedding.reshape(1, -1).astype("float32"))
            self.document_store.append({
                "task_id": data.get("task_id"),
                "question": question_text,
                "answer": data.get("Final answer"),
                "level": data.get("Level"),
            })
            return True
        except Exception as e:
            print(f"Error inserting data: {e}")
            return False

    def search_similar_questions(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Search for similar questions using vector similarity."""
        try:
            if self.index.ntotal == 0:
                return []
            query_embedding = self.embedding_model.encode([query])[0]
            k = min(k, self.index.ntotal)
            distances, indices = self.index.search(
                query_embedding.reshape(1, -1).astype("float32"), k
            )
            results = []
            for i, idx in enumerate(indices[0]):
                if 0 <= idx < len(self.document_store):
                    doc = self.document_store[idx]
                    results.append({
                        "task_id": doc["task_id"],
                        "question": doc["question"],
                        "answer": doc["answer"],
                        "similarity_score": 1 / (1 + distances[0][i]),
                        "distance": float(distances[0][i]),
                    })
            return results
        except Exception as e:
            print(f"Error searching similar questions: {e}")
            return []
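
# Assumed Supabase schema for the "questions" table used above (a sketch only;
# adjust to your project -- the embedding is stored as JSON here, though
# pgvector would also work):
#
#   create table questions (
#       id bigserial primary key,
#       task_id text,
#       question text,
#       final_answer text,
#       level text,
#       file_name text,
#       embedding jsonb
#   );

# Minimal usage sketch for the vector store (illustrative; runs fully in-memory
# when Supabase credentials are absent):
def _demo_vector_db() -> None:
    db = SupabaseFAISSVectorDB()
    db.insert_question_data({
        "task_id": "demo-1",
        "Question": "What is 2 + 2?",
        "Final answer": "4",
        "Level": "1",
    })
    print(db.search_similar_questions("what does two plus two equal", k=1))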

# ---- Enhanced Agent State ----
class EnhancedAgentState(TypedDict):
    """State structure for the enhanced multi-LLM agent system."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
    tools_used: List[str]
    reasoning: str
    similar_questions: List[Dict[str, Any]]

# ---- Enhanced Multi-LLM System ----
class HybridLangGraphMultiLLMSystem:
    """
    Advanced question-answering system with multi-LLM support and vector
    database integration.
    """

    def __init__(self, provider="groq"):
        self.provider = provider
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search,
        ]
        # Initialize vector database
        self.vector_db = SupabaseFAISSVectorDB()
        self.graph = self._build_graph()

    def _llm(self, model_name: str) -> ChatGroq:
        """Create a Groq LLM instance."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    def _build_graph(self) -> StateGraph:
        """Build the LangGraph state machine with enhanced capabilities."""
        # Initialize LLMs ("deepseek-chat" is not a Groq model id; use the
        # Groq-hosted DeepSeek distill instead)
        llama8_llm = self._llm("llama3-8b-8192")
        llama70_llm = self._llm("llama3-70b-8192")
        deepseek_llm = self._llm("deepseek-r1-distill-llama-70b")

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Route queries to the appropriate LLM based on complexity and content."""
            q = st["query"].lower()
            # Enhanced routing logic
            if any(keyword in q for keyword in ["calculate", "compute", "math", "multiply", "add", "subtract", "divide"]):
                t = "llama70"  # Use the more powerful model for calculations
            elif any(keyword in q for keyword in ["search", "find", "lookup", "wikipedia", "information about"]):
                t = "search_enhanced"  # Use search-enhanced processing
            elif "deepseek" in q or any(keyword in q for keyword in ["analyze", "reasoning", "complex"]):
                t = "deepseek"
            elif "llama-8" in q:
                t = "llama8"
            elif len(q.split()) > 20:  # Complex queries
                t = "llama70"
            else:
                t = "llama8"  # Default for simple queries

            # Search for similar questions
            similar_questions = self.vector_db.search_similar_questions(st["query"], k=3)
            return {**st, "agent_type": t, "tools_used": [], "reasoning": "", "similar_questions": similar_questions}

        def llama8_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process query with the Llama-3 8B model."""
            t0 = time.time()
            try:
                # Add similar-questions context if available
                context = ""
                if st.get("similar_questions"):
                    context = "\n\nSimilar questions for reference:\n"
                    for sq in st["similar_questions"][:2]:
                        context += f"Q: {sq['question']}\nA: {sq['answer']}\n"

                enhanced_query = f"""
Question: {st["query"]}
{context}
Please provide a direct, accurate answer to this question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama8_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "reasoning": "Used Llama-3 8B with similar questions context",
                        "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-8B"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}

        def llama70_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process query with the Llama-3 70B model."""
            t0 = time.time()
            try:
                # Add similar-questions context if available
                context = ""
                if st.get("similar_questions"):
                    context = "\n\nSimilar questions for reference:\n"
                    for sq in st["similar_questions"][:2]:
                        context += f"Q: {sq['question']}\nA: {sq['answer']}\n"

                enhanced_query = f"""
Question: {st["query"]}
{context}
Please provide a direct, accurate answer to this question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "reasoning": "Used Llama-3 70B for complex reasoning with context",
                        "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-70B"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process query with the DeepSeek model."""
            t0 = time.time()
            try:
                # Add similar-questions context if available
                context = ""
                if st.get("similar_questions"):
                    context = "\n\nSimilar questions for reference:\n"
                    for sq in st["similar_questions"][:2]:
                        context += f"Q: {sq['question']}\nA: {sq['answer']}\n"

                enhanced_query = f"""
Question: {st["query"]}
{context}
Please provide a direct, accurate answer to this question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = deepseek_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "reasoning": "Used DeepSeek for advanced reasoning and analysis",
                        "perf": {"time": time.time() - t0, "prov": "Groq-DeepSeek"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}

        def search_enhanced_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process query with search enhancement."""
            t0 = time.time()
            tools_used = []
            try:
                # Determine search strategy
                query = st["query"]
                search_results = ""
                if any(keyword in query.lower() for keyword in ["wikipedia", "wiki"]):
                    search_results = optimized_wiki_search.invoke({"query": query})
                    tools_used.append("wikipedia_search")
                else:
                    search_results = optimized_web_search.invoke({"query": query})
                    tools_used.append("web_search")

                # Add similar-questions context
                context = ""
                if st.get("similar_questions"):
                    context = "\n\nSimilar questions for reference:\n"
                    for sq in st["similar_questions"][:2]:
                        context += f"Q: {sq['question']}\nA: {sq['answer']}\n"

                enhanced_query = f"""
Original Question: {query}

Search Results:
{search_results}
{context}
Based on the search results and similar questions above, provide a direct answer to the original question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "tools_used": tools_used,
                        "reasoning": "Used search enhancement with similar questions context",
                        "perf": {"time": time.time() - t0, "prov": "Search-Enhanced-Llama70"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}

        # Build graph
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("llama8", llama8_node)
        g.add_node("llama70", llama70_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("search_enhanced", search_enhanced_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", lambda s: s["agent_type"], {
            "llama8": "llama8",
            "llama70": "llama70",
            "deepseek": "deepseek",
            "search_enhanced": "search_enhanced",
        })
        for node in ["llama8", "llama70", "deepseek", "search_enhanced"]:
            g.add_edge(node, END)
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Process a query through the enhanced multi-LLM system."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
            "tools_used": [],
            "reasoning": "",
            "similar_questions": [],
        }
        cfg = {"configurable": {"thread_id": f"enhanced_qa_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
            answer = out.get("final_answer", "").strip()
            # Ensure we don't return the question as the answer
            if answer == q or answer.startswith(q):
                return "Information not available"
            return answer if answer else "No answer generated"
        except Exception as e:
            return f"Error processing query: {e}"

    def load_metadata_from_jsonl(self, jsonl_file_path: str) -> int:
        """Load question metadata from a JSONL file into the vector database."""
        success_count = 0
        try:
            with open(jsonl_file_path, "r", encoding="utf-8") as file:
                for line_num, line in enumerate(file, 1):
                    try:
                        data = json.loads(line.strip())
                        if self.vector_db.insert_question_data(data):
                            success_count += 1
                        if line_num % 10 == 0:
                            print(f"Processed {line_num} records, {success_count} successful")
                    except json.JSONDecodeError as e:
                        print(f"JSON decode error on line {line_num}: {e}")
                    except Exception as e:
                        print(f"Error processing line {line_num}: {e}")
        except FileNotFoundError:
            print(f"File not found: {jsonl_file_path}")
        print(f"Loaded {success_count} questions into vector database")
        return success_count
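
# Expected JSONL record shape for load_metadata_from_jsonl (a sketch inferred
# from insert_question_data; one JSON object per line):
# {"task_id": "abc123", "Question": "...", "Final answer": "...",
#  "Level": "1", "file_name": ""}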

def build_graph(provider: str | None = None) -> StateGraph:
    """Build and return the compiled graph for the enhanced agent system."""
    return HybridLangGraphMultiLLMSystem(provider or "groq").graph
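
# Direct graph usage sketch (illustrative): the compiled graph is checkpointed
# with MemorySaver, so each invoke needs a thread_id in its config.
def _demo_graph_invoke() -> None:
    graph = build_graph()
    state = {
        "messages": [], "query": "What is 6 times 7?", "agent_type": "",
        "final_answer": "", "perf": {}, "agno_resp": "", "tools_used": [],
        "reasoning": "", "similar_questions": [],
    }
    print(graph.invoke(state, {"configurable": {"thread_id": "demo"}})["final_answer"])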

if __name__ == "__main__":
    # Initialize and test the system
    system = HybridLangGraphMultiLLMSystem()

    # Load metadata if available
    if os.path.exists("metadata.jsonl"):
        system.load_metadata_from_jsonl("metadata.jsonl")

    # Test queries
    test_questions = [
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
        "What is 25 multiplied by 17?",
        "Find information about artificial intelligence on Wikipedia",
    ]
    for question in test_questions:
        print(f"Question: {question}")
        answer = system.process_query(question)
        print(f"Answer: {answer}")
        print("-" * 50)