# NOTE: removed non-Python extraction artifacts (file-size/git-blame metadata)
# ============================
# model.py
# ============================
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
load_dotenv()
# Setup Supabase client from environment configuration.
url = os.getenv("SUPABASE_URL")
key = os.getenv("SUPABASE_KEY")
if not url or not key:
    # Fail fast with a clear message; create_client(None, ...) would
    # otherwise fail later with an opaque error.
    raise RuntimeError("SUPABASE_URL and SUPABASE_KEY must be set in the environment.")
supabase: Client = create_client(url, key)
# Tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers and return the result."""
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two numbers and return the result."""
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract second number from first and return the result."""
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide first number by second and return the result."""
    # Guard clause: reject a zero divisor with an explicit error.
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Return the modulus (remainder) of two numbers."""
    # Consistent with `divide`: surface a clear ValueError instead of a
    # bare ZeroDivisionError when the second operand is zero.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return 2 results."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = [page.page_content for page in loader.load()]
    return "\n\n---\n\n".join(pages)
@tool
def web_search(query: str) -> str:
    """Search the web using Tavily and return 3 results."""
    # BUG FIX: TavilySearch.invoke returns a dict payload
    # ({"query": ..., "results": [{"title", "url", "content", ...}]}),
    # not Document objects — the original `doc.page_content` access
    # would raise AttributeError at runtime.
    response = TavilySearch(max_results=3).invoke(query)
    if isinstance(response, dict):
        results = response.get("results", [])
        return "\n\n---\n\n".join(r.get("content", "") for r in results)
    # Defensive fallback in case a different tool version returns a string.
    return str(response)
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for academic papers and return 3 results."""
    # NOTE(review): "arvix" looks like a typo for "arxiv", but the name is
    # referenced in the tools list below, so it is kept unchanged.
    loader = ArxivLoader(query=query, load_max_docs=3)
    # Truncate each paper to its first 1000 characters to bound the output.
    snippets = [paper.page_content[:1000] for paper in loader.load()]
    return "\n\n---\n\n".join(snippets)
# Load system prompt from the working directory and wrap it as the
# SystemMessage that the retriever node prepends to every conversation.
# Explicit UTF-8 avoids locale-dependent decoding of the prompt file.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
# Vector search setup: embed with a sentence-transformers model and query
# the Supabase "documents" table via the "match_documents_langchain" RPC.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Expose the vector store as a retriever tool for the agent.
# NOTE(review): the tool name contains a space — some tool-calling APIs
# only accept [A-Za-z0-9_-] tool names; confirm this works with the
# providers used in build_graph.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from vector DB.",
)
# Tools list bound to the LLM in build_graph and wrapped by ToolNode:
# arithmetic helpers, external searches, and the vector-DB retriever.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arvix_search,
    retriever_tool,
]
# Build LangGraph
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph for the given LLM provider.

    Graph shape: START -> retriever -> assistant -> (tools -> assistant)*.

    Args:
        provider: One of "google", "groq" (default), or "huggingface".

    Returns:
        The compiled LangGraph runnable.

    Raises:
        ValueError: If ``provider`` is not a supported name.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # NOTE(review): env var is "GROQ_API" (not the conventional
        # GROQ_API_KEY) — confirm against the deployment environment.
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=os.getenv("GROQ_API"))
    elif provider == "huggingface":
        # BUG FIX: HuggingFaceEndpoint takes `endpoint_url` (or `repo_id`),
        # not `url`; the original keyword is rejected by its model.
        llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
            endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0))
    else:
        raise ValueError("Invalid provider")
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Invoke the tool-enabled LLM on the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Prepend the system prompt and, if found, a similar stored question."""
        docs = vector_store.similarity_search(state["messages"][0].content)
        if not docs:
            return {"messages": [sys_msg] + state["messages"]}
        similar_msg = HumanMessage(content=f"Reference: {docs[0].page_content}")
        return {"messages": [sys_msg] + state["messages"] + [similar_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the LLM emitted tool calls, otherwise END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
# ============================
# End of model.py
# ============================