"""Retrieval-augmented question-answering agent built on LangGraph.

Pipeline: START -> retriever -> assistant -> END.

* ``retriever`` embeds the latest user message, calls the Supabase RPC
  ``match_research_tasks`` and, when sufficiently similar rows come back,
  adds them to the conversation as a reference ``SystemMessage``.
* ``assistant`` sends the system prompt, the retrieved context and the
  conversation to the LLM and stores its reply.

Importing this module loads environment variables, builds the LLM, the
embedding model and the Supabase client, and compiles the graph.
"""

# =========================
# IMPORTS
# =========================

import os

# NOTE(review): `doc` is not used anywhere in the visible code; the import is
# retained because other parts of the project may rely on it.
from annotated_types import doc
from dotenv import load_dotenv
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage, SystemMessage
# from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from supabase.client import create_client

from llm import get_llm
from tools import *  # noqa: F401,F403 - kept as in the original file

# =========================================================
# Load environment variables
# =========================================================

load_dotenv()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

if not SUPABASE_URL:
    raise ValueError("Missing SUPABASE_URL")

if not SUPABASE_KEY:
    raise ValueError("Missing SUPABASE_KEY")

# =========================================================
# TUNABLES
# =========================================================

# Rows whose "similarity" is below this value are not shown to the LLM.
SIMILARITY_THRESHOLD = 0.70

# Number of rows requested from the RPC per question.
DEFAULT_MATCH_COUNT = 5

# =========================
# LLM SETUP
# =========================

llm = get_llm("groq")

# ======================================================
# EMBEDDINGS
# ======================================================

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# ======================================================
# SUPABASE
# ======================================================

# SUPABASE_URL = os.getenv("SUPABASE_URL")
# SUPABASE_KEY = os.getenv("SUPABASE_KEY")

supabase = create_client(
    SUPABASE_URL,
    SUPABASE_KEY,
)


# =========================================================
# RETRIEVAL
# =========================================================

def retrieve_documents(query: str, k: int = DEFAULT_MATCH_COUNT):
    """Return up to ``k`` stored tasks matching ``query``.

    The query is embedded with the module-level embedding model and passed,
    together with ``k``, to the Supabase RPC ``match_research_tasks``.

    Args:
        query: Text to search for.
        k: Value sent to the RPC as ``match_count``.

    Returns:
        The RPC's ``data`` payload, or an empty list when it is empty/None.
    """
    # Generate embedding
    query_embedding = embeddings.embed_query(query)

    # Call Supabase RPC
    response = supabase.rpc(
        "match_research_tasks",
        {
            "query_embedding": query_embedding,
            "match_count": k,
        },
    ).execute()

    docs = response.data or []

    print("\n===== RETRIEVED DOCS =====")
    print(docs)

    return docs


def _is_relevant(row, threshold: float = SIMILARITY_THRESHOLD) -> bool:
    """Return True when ``row`` carries a numeric similarity >= ``threshold``.

    Rows with a missing or non-numeric "similarity" are treated as irrelevant
    instead of raising.
    """
    similarity = row.get("similarity") if isinstance(row, dict) else None
    if isinstance(similarity, bool) or not isinstance(similarity, (int, float)):
        return False
    return similarity >= threshold


def _format_example(row) -> str:
    """Render one retrieved row as a block of text for the prompt."""
    return (
        "\n"
        f"Question: {row['question']}\n"
        f"Answer: {row['final_answer']}\n"
        f"Similarity: {row['similarity']:.4f}\n"
    )


# =========================================================
# RETRIEVER NODE
# =========================================================

def retriever_node(state: MessagesState):
    """Look up similar solved tasks for the latest message.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update. When no sufficiently similar row is found the current
        messages are returned unchanged; otherwise a ``SystemMessage`` with
        the retrieved examples is returned in front of them.
    """
    messages = state["messages"]
    unchanged = {"messages": messages}

    # Nothing to search for.
    if not messages:
        return unchanged

    # Last user message. Content may be a list of parts for multimodal
    # messages, so only call .strip() on plain strings.
    content = messages[-1].content
    user_question = content.strip() if isinstance(content, str) else str(content)

    print("\n===== USER QUESTION =====")
    print(user_question)

    # A blank question cannot match anything useful; skip the round trip.
    if not user_question:
        return unchanged

    # Retrieve similar tasks
    docs = retrieve_documents(user_question, k=DEFAULT_MATCH_COUNT)

    # No docs
    if not docs:
        return unchanged

    # Similarity filtering
    filtered_docs = [row for row in docs if _is_relevant(row)]

    print("\n===== FILTERED DOCS =====")
    print(filtered_docs)

    # Nothing good enough
    if not filtered_docs:
        return unchanged

    # Build retrieval context
    context = "\n\n".join(_format_example(row) for row in filtered_docs)

    retrieval_message = SystemMessage(
        content=(
            "\n"
            "You are given previously solved similar tasks.\n"
            "\n"
            "Use them ONLY as reference.\n"
            "\n"
            "Retrieved Examples:\n"
            "\n"
            f"{context}\n"
        )
    )

    # NOTE: the order of this list does not decide the order in the graph
    # state. MessagesState merges updates with the `add_messages` reducer,
    # which updates messages that already have a known ID in place and appends
    # new ones. assistant_node is responsible for placing the retrieved
    # context ahead of the user question.
    return {
        "messages": [retrieval_message] + messages
    }


# =========================================================
# ASSISTANT NODE
# =========================================================

ASSISTANT_SYSTEM_PROMPT = """
You are a precise question-answering assistant.

RULES:
- Use retrieved examples if relevant
- Prefer answers from highly similar examples
- Do NOT hallucinate
- Keep answers concise
- Output ONLY the final answer
"""


def assistant_node(state: MessagesState):
    """Ask the LLM to answer, given the conversation and retrieved context.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update containing the LLM response as the only message.
    """
    messages = state["messages"]

    system_prompt = SystemMessage(content=ASSISTANT_SYSTEM_PROMPT)

    # Stable partition: system messages (the retrieved context) first, then
    # the rest of the conversation in its original order. The reducer appends
    # the retrieval message after the user question, so it is moved forward
    # here to get "context first, question last".
    context_messages = [m for m in messages if isinstance(m, SystemMessage)]
    conversation = [m for m in messages if not isinstance(m, SystemMessage)]

    final_messages = [system_prompt, *context_messages, *conversation]

    print("\n===== FINAL PROMPT TO LLM =====")

    for message in final_messages:
        print(f"\n[{message.type.upper()}]")
        print(message.content)

    response = llm.invoke(final_messages)

    return {
        "messages": [response]
    }


# =========================================================
# BUILD GRAPH
# =========================================================

graph = StateGraph(MessagesState)

graph.add_node("retriever", retriever_node)
graph.add_node("assistant", assistant_node)

graph.add_edge(START, "retriever")
graph.add_edge("retriever", "assistant")
graph.add_edge("assistant", END)

app = graph.compile()


# =========================================================
# ASK FUNCTION
# =========================================================

def ask_agent(question: str):
    """Run ``question`` through the graph and return the final answer.

    Args:
        question: The user's question.

    Returns:
        The ``content`` of the last message in the resulting state, which is
        the reply produced by ``assistant_node``.
    """
    result = app.invoke({
        "messages": [
            HumanMessage(content=question)
        ]
    })

    return result["messages"][-1].content


# =========================================================
# TEST
# =========================================================

def main() -> None:
    """Interactive loop: read questions from stdin until exit/quit/EOF."""
    while True:
        try:
            q = input("\nAsk: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session without a traceback.
            print()
            break

        if q.lower() in ("exit", "quit"):
            break

        # Do not spend an embedding + LLM call on an empty line.
        if not q:
            continue

        answer = ask_agent(q)

        print("\n===== FINAL ANSWER =====")
        print(answer)


if __name__ == "__main__":
    main()