Prasanthkumar committed on
Commit
c72bd68
·
verified ·
1 Parent(s): a2868dc

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +54 -39
model.py CHANGED
@@ -1,8 +1,11 @@
 
 
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
@@ -11,100 +14,115 @@ from langchain_community.vectorstores import SupabaseVectorStore
11
  from langchain_core.messages import SystemMessage, HumanMessage
12
  from langchain_core.tools import tool
13
  from langchain_tavily import TavilySearch
 
14
  from supabase.client import Client, create_client
15
 
16
  load_dotenv()
17
 
 
18
  url = os.getenv("SUPABASE_URL")
19
  key = os.getenv("SUPABASE_KEY")
 
20
 
21
- # Math Tools
22
  @tool
23
  def multiply(a: int, b: int) -> int:
 
24
  return a * b
25
 
26
  @tool
27
  def add(a: int, b: int) -> int:
 
28
  return a + b
29
 
30
  @tool
31
  def subtract(a: int, b: int) -> int:
 
32
  return a - b
33
 
34
  @tool
35
  def divide(a: int, b: int) -> float:
 
36
  if b == 0:
37
  raise ValueError("Cannot divide by zero.")
38
  return a / b
39
 
40
  @tool
41
  def modulus(a: int, b: int) -> int:
 
42
  return a % b
43
 
44
- # Search Tools
45
  @tool
46
  def wiki_search(query: str) -> str:
47
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
48
- return "\n---\n".join([doc.page_content for doc in search_docs])
 
49
 
50
  @tool
51
  def web_search(query: str) -> str:
52
- tavily = TavilySearch(k=3)
53
- results = tavily.invoke(query)
54
- return "\n---\n".join([doc.page_content for doc in results])
55
 
56
  @tool
57
  def arvix_search(query: str) -> str:
58
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
59
- return "\n---\n".join([doc.page_content[:1000] for doc in search_docs])
 
60
 
61
  # Load system prompt
62
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
63
  system_prompt = f.read()
 
64
  sys_msg = SystemMessage(content=system_prompt)
65
 
66
- # Vector store setup
67
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
68
- supabase: Client = create_client(url, key)
69
  vector_store = SupabaseVectorStore(
70
  client=supabase,
71
  embedding=embeddings,
72
  table_name="documents",
73
  query_name="match_documents_langchain",
74
  )
75
- retriever_tool = tool(name="Question Search", description="Retrieve similar questions from vector DB")(vector_store.as_retriever().invoke)
76
 
77
- tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search, retriever_tool]
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def build_graph(provider: str = "groq"):
80
  if provider == "google":
81
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
82
  elif provider == "groq":
83
- api_key = os.getenv('GROQ_API')
84
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=api_key)
85
  elif provider == "huggingface":
86
- llm = ChatHuggingFace(
87
- llm=HuggingFaceEndpoint(
88
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
89
- temperature=0,
90
- ),
91
- )
92
  else:
93
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
94
 
95
  llm_with_tools = llm.bind_tools(tools)
96
 
97
  def assistant(state: MessagesState):
98
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
99
 
100
  def retriever(state: MessagesState):
101
- similar_docs = vector_store.similarity_search(state["messages"][0].content)
102
- if not similar_docs:
103
  return {"messages": [sys_msg] + state["messages"]}
104
- example_msg = HumanMessage(
105
- content=f"Here is a related example to help: \n\n{similar_docs[0].page_content}"
106
- )
107
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
108
 
109
  builder = StateGraph(MessagesState)
110
  builder.add_node("retriever", retriever)
@@ -117,10 +135,7 @@ def build_graph(provider: str = "groq"):
117
 
118
  return builder.compile()
119
 
120
- if __name__ == "__main__":
121
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
122
- graph = build_graph(provider="groq")
123
- messages = [HumanMessage(content=question)]
124
- output = graph.invoke({"messages": messages})
125
- for m in output["messages"]:
126
- m.pretty_print()
 
1
+ # ============================
2
+ # model.py
3
+ # ============================
4
+
5
  import os
6
  from dotenv import load_dotenv
7
  from langgraph.graph import START, StateGraph, MessagesState
8
+ from langgraph.prebuilt import tools_condition, ToolNode
 
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_groq import ChatGroq
11
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain_tavily import TavilySearch
17
+ from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
 
20
  load_dotenv()
21
 
22
+ # Setup Supabase
23
  url = os.getenv("SUPABASE_URL")
24
  key = os.getenv("SUPABASE_KEY")
25
+ supabase: Client = create_client(url, key)
26
 
27
+ # Tools
28
  @tool
29
  def multiply(a: int, b: int) -> int:
30
+ """Multiply two numbers and return the result."""
31
  return a * b
32
 
33
  @tool
34
  def add(a: int, b: int) -> int:
35
+ """Add two numbers and return the result."""
36
  return a + b
37
 
38
  @tool
39
  def subtract(a: int, b: int) -> int:
40
+ """Subtract second number from first and return the result."""
41
  return a - b
42
 
43
  @tool
44
  def divide(a: int, b: int) -> float:
45
+ """Divide first number by second and return the result."""
46
  if b == 0:
47
  raise ValueError("Cannot divide by zero.")
48
  return a / b
49
 
50
  @tool
51
  def modulus(a: int, b: int) -> int:
52
+ """Return the modulus (remainder) of two numbers."""
53
  return a % b
54
 
 
55
  @tool
56
  def wiki_search(query: str) -> str:
57
+ """Search Wikipedia and return 2 results."""
58
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
59
+ return "\n\n---\n\n".join(doc.page_content for doc in docs)
60
 
61
  @tool
62
  def web_search(query: str) -> str:
63
+ """Search the web using Tavily and return 3 results."""
64
+ docs = TavilySearch(max_results=3).invoke(query)
65
+ return "\n\n---\n\n".join(doc.page_content for doc in docs)
66
 
67
  @tool
68
  def arvix_search(query: str) -> str:
69
+ """Search Arxiv for academic papers and return 3 results."""
70
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
71
+ return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
72
 
73
  # Load system prompt
74
+ with open("system_prompt.txt", "r") as f:
75
  system_prompt = f.read()
76
+
77
  sys_msg = SystemMessage(content=system_prompt)
78
 
79
# Vector search setup
# Embedding model handed to the vector store below.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# The store uses the module-level Supabase client; the table and query names
# are passed through to SupabaseVectorStore unchanged.
vector_store = SupabaseVectorStore(
    embedding=embeddings,
    client=supabase,
    query_name="match_documents_langchain",
    table_name="documents",
)
 
87
 
88
# Expose the vector store's retriever as a tool the model can call.
# The tool name must not contain spaces: function-calling providers only accept
# names made of letters, digits, underscores and dashes, so the previous
# "Question Search" would be rejected when the tools are bound to the model.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="Retrieve similar questions from vector DB.",
)
93
+
94
# Tools list
# Order: arithmetic helpers, search helpers, then the vector-store retriever.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    retriever_tool,
]
100
+
101
+ # Build LangGraph
102
 
103
  def build_graph(provider: str = "groq"):
104
  if provider == "google":
105
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
106
  elif provider == "groq":
107
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0, api_key=os.getenv("GROQ_API"))
 
108
  elif provider == "huggingface":
109
+ llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
110
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
111
+ temperature=0))
 
 
 
112
  else:
113
+ raise ValueError("Invalid provider")
114
 
115
  llm_with_tools = llm.bind_tools(tools)
116
 
117
  def assistant(state: MessagesState):
118
+ return {"messages": [llm_with_tools.invoke(state["messages\])]}
119
 
120
  def retriever(state: MessagesState):
121
+ docs = vector_store.similarity_search(state["messages"][0].content)
122
+ if not docs:
123
  return {"messages": [sys_msg] + state["messages"]}
124
+ similar_msg = HumanMessage(content=f"Reference: {docs[0].page_content}")
125
+ return {"messages": [sys_msg] + state["messages"] + [similar_msg]}
 
 
126
 
127
  builder = StateGraph(MessagesState)
128
  builder.add_node("retriever", retriever)
 
135
 
136
  return builder.compile()
137
 
138
+
139
+ # ============================
140
+ # Save this file as model.py; app.py must be regenerated to match these changes.
141
+ # ============================