Spaces:
Build error
Build error
| # ========================= | |
| # IMPORTS | |
| # ========================= | |
| from annotated_types import doc | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langgraph.graph import StateGraph, END, START | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langgraph.graph import MessagesState | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from llm import get_llm | |
| from tools import * | |
| from supabase.client import create_client | |
| from langchain_community.vectorstores import SupabaseVectorStore | |
| # from langchain.tools.retriever import create_retriever_tool | |
| from langchain_core.tools import create_retriever_tool | |
| import os | |
| from dotenv import load_dotenv | |
# =========================================================
# Load environment variables
# =========================================================
# Pull values from a local .env file (if present) into the process
# environment before reading them below.
load_dotenv()

# Both Supabase settings are mandatory: stop at import time with a clear
# message when either one is missing or empty.
SUPABASE_URL = os.getenv("SUPABASE_URL")
if not SUPABASE_URL:
    raise ValueError("Missing SUPABASE_URL")

SUPABASE_KEY = os.getenv("SUPABASE_KEY")
if not SUPABASE_KEY:
    raise ValueError("Missing SUPABASE_KEY")
# =========================
# LLM SETUP
# =========================
# Chat model used by assistant_node; "groq" is passed to the project's
# get_llm() helper (presumably selects the provider — see llm.py).
llm = get_llm("groq")

# ======================================================
# EMBEDDINGS
# ======================================================
# Model used to embed user questions before the similarity search.
# NOTE(review): this presumably must be the same model that produced the
# vectors stored in Supabase, or the similarity scores are meaningless — confirm.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL_NAME)

# ======================================================
# SUPABASE
# ======================================================
# Client used by retrieve_documents() to call the vector-search RPC; the
# URL and key were validated when the environment was loaded.
supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
# =========================================================
# RETRIEVAL
# =========================================================
def retrieve_documents(query: str, k: int = 5):
    """Fetch stored tasks similar to *query* via a Supabase RPC.

    The query is embedded with the module-level ``embeddings`` model, and the
    vector is sent to the ``match_research_tasks`` RPC together with
    ``match_count=k``.

    Args:
        query: Free-text question to search for.
        k: Forwarded as ``match_count`` (presumably the maximum number of rows
            returned; the exact meaning lives in the SQL function — confirm).

    Returns:
        The RPC's ``data`` payload, or ``[]`` when that payload is empty/None.
        Rows are presumably dicts with ``question``, ``final_answer`` and
        ``similarity`` keys (that is what the retriever node reads).
    """
    rpc_args = {
        "query_embedding": embeddings.embed_query(query),
        "match_count": k,
    }
    rows = supabase.rpc("match_research_tasks", rpc_args).execute().data or []

    # Debug trace of the raw matches.
    print("\n===== RETRIEVED DOCS =====")
    print(rows)
    return rows
# =========================================================
# RETRIEVER NODE
# =========================================================
# Number of stored tasks requested from the vector search per question.
RETRIEVAL_TOP_K = 5
# Minimum similarity score a stored task needs to be shown to the LLM.
MIN_SIMILARITY = 0.70


def retriever_node(state: MessagesState):
    """Add previously solved, similar tasks to the conversation as context.

    The latest message is used as the search query. Up to ``RETRIEVAL_TOP_K``
    stored tasks are fetched with ``retrieve_documents``, and only those scoring
    at least ``MIN_SIMILARITY`` are kept. If any survive, one ``SystemMessage``
    listing their question, answer and score is emitted.

    Args:
        state: Graph state; ``state["messages"][-1]`` is the user question.

    Returns:
        A partial state update: ``{"messages": []}`` when nothing relevant was
        found, otherwise ``{"messages": [retrieval_message]}``.

    Note:
        ``MessagesState`` merges node output with the ``add_messages`` reducer.
        That reducer replaces messages whose id already exists and appends new
        ones. The retrieval message is therefore appended *after* the user
        question. Echoing the existing history back, as an earlier version did,
        did not move it in front; it only rewrote those messages in place.
    """
    user_question = state["messages"][-1].content.strip()
    print("\n===== USER QUESTION =====")
    print(user_question)

    docs = retrieve_documents(user_question, k=RETRIEVAL_TOP_K)
    if not docs:
        # Empty update: the conversation reaches the assistant unchanged.
        return {"messages": []}

    # Rows with a missing/NULL score count as non-matches instead of raising
    # KeyError/TypeError in the comparison.
    filtered_docs = [
        row for row in docs
        if (row.get("similarity") or 0.0) >= MIN_SIMILARITY
    ]
    print("\n===== FILTERED DOCS =====")
    print(filtered_docs)

    if not filtered_docs:
        # Nothing similar enough to be useful as a reference.
        return {"messages": []}

    context = "\n\n".join(
        f"""
Question: {row['question']}
Answer: {row['final_answer']}
Similarity: {row['similarity']:.4f}
"""
        for row in filtered_docs
    )

    retrieval_message = SystemMessage(
        content=f"""
You are given previously solved similar tasks.
Use them ONLY as reference.
Retrieved Examples:
{context}
"""
    )

    # Return only the new message; the reducer appends it to the history.
    return {"messages": [retrieval_message]}
# =========================================================
# ASSISTANT NODE
# =========================================================
def assistant_node(state: MessagesState):
    """Answer the conversation with the LLM under a fixed set of rules.

    A rules ``SystemMessage`` is placed in front of the accumulated messages.
    Those may include retrieval context added by the retriever node. The whole
    prompt is echoed to stdout for debugging, and the model reply is returned
    as a state update.
    """
    rules = SystemMessage(content="""
You are a precise question-answering assistant.
RULES:
- Use retrieved examples if relevant
- Prefer answers from highly similar examples
- Do NOT hallucinate
- Keep answers concise
- Output ONLY the final answer
""")
    prompt = [rules, *state["messages"]]

    print("\n===== FINAL PROMPT TO LLM =====")
    for message in prompt:
        role = message.type.upper()
        print(f"\n[{role}]")
        print(message.content)

    reply = llm.invoke(prompt)
    return {"messages": [reply]}
# =========================================================
# BUILD GRAPH
# =========================================================
def _wire_graph(builder: StateGraph) -> StateGraph:
    """Register the nodes on *builder* and chain them START -> retriever -> assistant -> END."""
    for node_name, node_fn in (
        ("retriever", retriever_node),
        ("assistant", assistant_node),
    ):
        builder.add_node(node_name, node_fn)

    # Linear pipeline: connect each hop to the next one.
    hops = [START, "retriever", "assistant", END]
    for source, target in zip(hops, hops[1:]):
        builder.add_edge(source, target)
    return builder


graph = _wire_graph(StateGraph(MessagesState))
# Compiled, invokable graph used by ask_agent().
app = graph.compile()
# =========================================================
# ASK FUNCTION
# =========================================================
def ask_agent(question: str):
    """Run *question* through the compiled graph and return the final reply text.

    The question starts the graph as a single ``HumanMessage``. The content of
    the last message in the resulting state (the assistant's reply) is returned.
    """
    initial_state = {"messages": [HumanMessage(content=question)]}
    final_state = app.invoke(initial_state)
    last_message = final_state["messages"][-1]
    return last_message.content
# =========================================================
# TEST
# =========================================================
if __name__ == "__main__":
    # Interactive loop: ask questions until "exit"/"quit" or end of input.
    while True:
        try:
            q = input("\nAsk: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of
            # dumping a traceback.
            print()
            break
        if not q:
            # Don't spend an embedding + LLM call on a blank line.
            continue
        if q.lower() in ("exit", "quit"):
            break
        answer = ask_agent(q)
        print("\n===== FINAL ANSWER =====")
        print(answer)