File size: 4,653 Bytes
c72bd68
 
 
 
4852cd1
 
 
c72bd68
4852cd1
 
 
166ba87
4852cd1
 
 
166ba87
c72bd68
4852cd1
166ba87
4852cd1
 
c72bd68
e4495f7
 
c72bd68
dba100b
c72bd68
4852cd1
 
c72bd68
4852cd1
 
 
 
c72bd68
4852cd1
 
 
 
c72bd68
4852cd1
 
 
166ba87
c72bd68
4852cd1
 
 
 
 
 
c72bd68
4852cd1
 
 
 
c72bd68
 
 
4852cd1
 
 
c72bd68
 
 
4852cd1
 
 
c72bd68
 
 
4852cd1
166ba87
c72bd68
4852cd1
c72bd68
4852cd1
 
c72bd68
166ba87
4852cd1
 
166ba87
4852cd1
 
 
 
c72bd68
 
 
 
 
 
 
 
 
 
 
 
 
 
4852cd1
 
 
 
 
c72bd68
4852cd1
c72bd68
 
 
4852cd1
c72bd68
166ba87
4852cd1
 
 
280e958
b9b310b
4852cd1
c72bd68
 
f6a6958
c72bd68
 
b9b310b
4852cd1
 
 
 
 
 
166ba87
4852cd1
 
 
 
c72bd68
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# ============================
# model.py
# ============================

import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

load_dotenv()

# Setup Supabase
# Fail fast with a clear message when credentials are missing; otherwise
# create_client raises a confusing error from deep inside the client library.
url = os.getenv("SUPABASE_URL")
key = os.getenv("SUPABASE_KEY")
if not url or not key:
    raise RuntimeError("SUPABASE_URL and SUPABASE_KEY environment variables must be set.")
supabase: Client = create_client(url, key)

# Tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers and return the result."""
    # Basic arithmetic tool exposed to the agent's LLM.
    product = a * b
    return product

@tool
def add(a: int, b: int) -> int:
    """Add two numbers and return the result."""
    # Basic arithmetic tool exposed to the agent's LLM.
    total = a + b
    return total

@tool
def subtract(a: int, b: int) -> int:
    """Subtract second number from first and return the result."""
    # Basic arithmetic tool exposed to the agent's LLM.
    difference = a - b
    return difference

@tool
def divide(a: int, b: int) -> float:
    """Divide first number by second and return the result."""
    # Guard clause: division by zero is undefined, so reject it up front.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")

@tool
def modulus(a: int, b: int) -> int:
    """Return the modulus (remainder) of two numbers."""
    # Basic arithmetic tool exposed to the agent's LLM.
    remainder = a % b
    return remainder

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return 2 results."""
    # Load up to two matching pages and join their contents with a visible
    # separator so the LLM can tell the results apart.
    separator = "\n\n---\n\n"
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return separator.join(page.page_content for page in pages)

@tool
def web_search(query: str) -> str:
    """Search the web using Tavily and return 3 results."""
    # BUG FIX: TavilySearch.invoke() returns a dict shaped like
    # {"results": [{"title", "url", "content", ...}, ...]}, not Document
    # objects, so the original `doc.page_content` access raised
    # AttributeError at runtime.
    response = TavilySearch(max_results=3).invoke(query)
    items = response.get("results", []) if isinstance(response, dict) else response
    snippets = []
    for item in items:
        if isinstance(item, dict):
            snippets.append(item.get("content", ""))
        else:
            # Fallback for Document-like objects, preserving the old intent.
            snippets.append(getattr(item, "page_content", str(item)))
    return "\n\n---\n\n".join(snippets)

@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for academic papers and return 3 results."""
    # NOTE: the name "arvix" is a misspelling of "arxiv", but it is kept
    # because the tools list registers this function under this name.
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each paper to its first 1000 characters to keep context small.
    snippets = [paper.page_content[:1000] for paper in papers]
    return "\n\n---\n\n".join(snippets)

# Load system prompt
# Explicit UTF-8 avoids platform-dependent default encodings
# (e.g. cp1252 on Windows) silently corrupting the prompt text.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# Wrapped once at import time; reused for every graph invocation.
sys_msg = SystemMessage(content=system_prompt)

# Vector search setup
# Embeddings are computed locally with a sentence-transformers model; the
# vectors themselves live in the Supabase "documents" table and are queried
# through the "match_documents_langchain" SQL function.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Expose the vector store to the agent as a retriever tool.
# NOTE(review): the tool name contains a space; some providers restrict
# function/tool names to [a-zA-Z0-9_-] — confirm this works with the
# chosen LLM backends.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from vector DB.",
)

# Tools list
# Everything the assistant node may call: arithmetic helpers, three search
# backends (Wikipedia, Tavily web, Arxiv), and the vector-DB retriever.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arvix_search,
    retriever_tool,
]

# Build LangGraph

def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Flow: START -> retriever (prepends the system prompt and, when found, a
    similar question from the vector DB) -> assistant (tool-calling LLM),
    looping through the tools node until the LLM stops requesting tools.

    Args:
        provider: LLM backend to use — "google", "groq", or "huggingface".

    Returns:
        A compiled LangGraph, invokable with {"messages": [...]}.

    Raises:
        ValueError: If ``provider`` is not a supported backend.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # NOTE(review): env var is GROQ_API here, not the conventional
        # GROQ_API_KEY — confirm it matches the deployed .env.
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=os.getenv("GROQ_API"))
    elif provider == "huggingface":
        llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
            url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0))
    else:
        # Name the offending value so misconfiguration is easy to diagnose.
        raise ValueError(
            f"Invalid provider: {provider!r}. Expected 'google', 'groq', or 'huggingface'."
        )

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # One LLM step over the accumulated conversation; may emit tool calls.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        # Look up a similar previously-seen question using the first message
        # as the query. Assumes state["messages"] is non-empty — TODO confirm
        # callers always seed the state with at least one message.
        docs = vector_store.similarity_search(state["messages"][0].content)
        if not docs:
            return {"messages": [sys_msg] + state["messages"]}
        similar_msg = HumanMessage(content=f"Reference: {docs[0].page_content}")
        return {"messages": [sys_msg] + state["messages"] + [similar_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the LLM requested tool calls,
    # otherwise to END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()


# ============================
# End of model.py
# ============================