# Final-Agent-Course / agent.py
# Author: Chitranshu-9 — commit 84c9c5b ("Final agent submission"), 5.52 kB
# =========================
# IMPORTS
# =========================
from annotated_types import doc
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, SystemMessage
from llm import get_llm
from tools import *
from supabase.client import create_client
from langchain_community.vectorstores import SupabaseVectorStore
# from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import create_retriever_tool
import os
from dotenv import load_dotenv
# =========================================================
# Load environment variables
# =========================================================
# Read settings from a .env file (python-dotenv), then fail fast if either
# Supabase credential is missing or empty.
load_dotenv()

SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_KEY")

for _env_name, _env_value in (
    ("SUPABASE_URL", SUPABASE_URL),
    ("SUPABASE_KEY", SUPABASE_KEY),
):
    if not _env_value:
        raise ValueError(f"Missing {_env_name}")

# Loop variables are scratch only; keep them out of the module namespace.
del _env_name, _env_value
# =========================
# LLM SETUP
# =========================
# Provider key passed to the project-local ``get_llm`` factory; the resulting
# chat model is what ``assistant_node`` invokes.
_LLM_PROVIDER = "groq"
llm = get_llm(_LLM_PROVIDER)
# ======================================================
# EMBEDDINGS
# ======================================================
# Sentence-transformers model used to embed incoming questions.
# NOTE(review): presumably this must be the same model that produced the
# vectors stored in Supabase, or the similarity scores are meaningless — verify.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL_NAME)
# ======================================================
# SUPABASE
# ======================================================
# Client used for the similarity-search RPC. SUPABASE_URL / SUPABASE_KEY are
# read from the environment and checked for presence earlier in this module.
supabase = create_client(
    SUPABASE_URL,
    SUPABASE_KEY,
)
# =========================================================
# RETRIEVAL
# =========================================================
def retrieve_documents(query: str, k: int = 5):
    """Look up stored tasks similar to ``query`` via a Supabase RPC.

    The query is embedded with the module-level ``embeddings`` model and sent
    to the ``match_research_tasks`` database function together with
    ``match_count=k``.

    Args:
        query: Free-text question to embed.
        k: Value forwarded as ``match_count`` to the RPC.

    Returns:
        The RPC's ``data`` payload, or ``[]`` when it is empty/None. Rows are
        presumably dicts with ``question``, ``final_answer`` and
        ``similarity`` keys (that is what ``retriever_node`` reads) — confirm
        against the SQL function definition.
    """
    rpc_params = {
        "query_embedding": embeddings.embed_query(query),
        "match_count": k,
    }
    rpc_result = supabase.rpc("match_research_tasks", rpc_params).execute()

    matches = rpc_result.data or []

    # Debug trace of the raw retrieval result.
    print("\n===== RETRIEVED DOCS =====")
    print(matches)
    return matches
# =========================================================
# RETRIEVER NODE
# =========================================================
# How many candidate tasks to request from the similarity RPC.
RETRIEVAL_TOP_K = 5
# Minimum ``similarity`` value (as returned by the RPC) a candidate must reach
# to be shown to the LLM.
RETRIEVAL_MIN_SIMILARITY = 0.70


def retriever_node(state: MessagesState):
    """Add a SystemMessage containing similar, previously solved tasks.

    The last message in ``state["messages"]`` is treated as the user question.
    It is looked up with :func:`retrieve_documents`, and only rows whose
    ``similarity`` is at least ``RETRIEVAL_MIN_SIMILARITY`` are kept.

    Returns:
        A LangGraph state update. ``MessagesState`` merges ``messages`` with
        the ``add_messages`` reducer, which appends new messages and replaces
        existing ones that share an ID. This node therefore returns only the
        *new* retrieval message, or an empty list when nothing qualifies.
        Re-sending the existing history would be a no-op. Because the reducer
        appends, the retrieval message ends up *after* the user question in
        the state; it cannot be prepended from here.
    """
    messages = state["messages"]
    if not messages:
        return {"messages": []}

    # Last message is the user's question. Non-string (multi-part) content
    # falls back to its str() form instead of crashing on .strip().
    content = messages[-1].content
    user_question = (content if isinstance(content, str) else str(content)).strip()
    print("\n===== USER QUESTION =====")
    print(user_question)

    # Nothing meaningful to search for with a blank question.
    if not user_question:
        return {"messages": []}

    docs = retrieve_documents(user_question, k=RETRIEVAL_TOP_K)
    if not docs:
        return {"messages": []}

    # Rows with a missing/None score are treated as non-matches rather than
    # raising KeyError/TypeError.
    filtered_docs = [
        hit
        for hit in docs
        if (hit.get("similarity") or 0.0) >= RETRIEVAL_MIN_SIMILARITY
    ]
    print("\n===== FILTERED DOCS =====")
    print(filtered_docs)

    # Nothing similar enough to be worth showing.
    if not filtered_docs:
        return {"messages": []}

    context = "\n\n".join([
        f"""
Question: {hit['question']}
Answer: {hit['final_answer']}
Similarity: {hit['similarity']:.4f}
"""
        for hit in filtered_docs
    ])

    retrieval_message = SystemMessage(
        content=f"""
You are given previously solved similar tasks.
Use them ONLY as reference.
Retrieved Examples:
{context}
"""
    )

    # Only the new message is returned; add_messages appends it to the state.
    return {"messages": [retrieval_message]}
# =========================================================
# ASSISTANT NODE
# =========================================================
def assistant_node(state: MessagesState):
    """Invoke the LLM on the conversation and return its reply as a state update.

    The prompt sent to the model is ordered as follows:

    1. the fixed instruction prompt below,
    2. every ``SystemMessage`` already in the state (such as the retrieval
       context added by ``retriever_node``), in their original order,
    3. all remaining (user/assistant) messages, in their original order.

    The ``add_messages`` reducer behind ``MessagesState`` appends new
    messages. Without this reordering, the retrieval context would reach the
    model *after* the question it is meant to inform. Some chat integrations
    also only accept system messages at the start of the conversation.

    Returns:
        ``{"messages": [response]}``. The reducer appends the reply, so it
        becomes the last message in the final state.
    """
    history = state["messages"]

    system_prompt = SystemMessage(content="""
You are a precise question-answering assistant.
RULES:
- Use retrieved examples if relevant
- Prefer answers from highly similar examples
- Do NOT hallucinate
- Keep answers concise
- Output ONLY the final answer
""")

    # Stable partition: system context first, then the dialogue.
    context_messages = [m for m in history if isinstance(m, SystemMessage)]
    dialogue = [m for m in history if not isinstance(m, SystemMessage)]
    final_messages = [system_prompt, *context_messages, *dialogue]

    # Debug trace of exactly what the model receives.
    print("\n===== FINAL PROMPT TO LLM =====")
    for message in final_messages:
        print(f"\n[{message.type.upper()}]")
        print(message.content)

    response = llm.invoke(final_messages)
    return {
        "messages": [response]
    }
# =========================================================
# BUILD GRAPH
# =========================================================
# Linear pipeline: START -> retriever -> assistant -> END. No tool or
# conditional edges are wired in, so each question triggers exactly one
# llm.invoke (inside assistant_node).
graph = StateGraph(MessagesState)

for _node_name, _node_fn in (
    ("retriever", retriever_node),
    ("assistant", assistant_node),
):
    graph.add_node(_node_name, _node_fn)

for _edge_src, _edge_dst in (
    (START, "retriever"),
    ("retriever", "assistant"),
    ("assistant", END),
):
    graph.add_edge(_edge_src, _edge_dst)

# Loop variables are scratch only; keep them out of the module namespace.
del _node_name, _node_fn, _edge_src, _edge_dst

app = graph.compile()
# =========================================================
# ASK FUNCTION
# =========================================================
def ask_agent(question: str):
    """Run the compiled graph on one question and return the reply's content.

    The question is wrapped in a single ``HumanMessage``. The answer is the
    ``content`` of the last message in the resulting state, which is the
    assistant node's reply.
    """
    initial_state = {"messages": [HumanMessage(content=question)]}
    final_state = app.invoke(initial_state)

    last_message = final_state["messages"][-1]
    return last_message.content
# =========================================================
# TEST
# =========================================================
if __name__ == "__main__":
    # Minimal REPL for manual testing. Type a question, or "exit"/"quit"
    # (case-insensitive, surrounding spaces ignored) to stop. End-of-input
    # (Ctrl-D, closed/piped stdin) and Ctrl-C at the prompt also exit cleanly.
    while True:
        try:
            q = input("\nAsk: ")
        except (EOFError, KeyboardInterrupt):
            # Without this, input() raises and the session ends in a traceback.
            print()
            break

        q = q.strip()
        if not q:
            # Don't spend an embedding + LLM call on an empty line.
            continue
        if q.lower() in ("exit", "quit"):
            break

        answer = ask_agent(q)
        print("\n===== FINAL ANSWER =====")
        print(answer)