cacaprog committed on
Commit
d3204ae
·
verified ·
1 Parent(s): 052a991

Updated app.py with langchain

Browse files
Files changed (1) hide show
  1. app.py +114 -116
app.py CHANGED
@@ -1,25 +1,24 @@
1
  import os
2
  import gradio as gr
3
  import requests
4
- import inspect
5
- import pandas as pd
6
  import json
7
-
8
- from llama_index.agent.react import ReActAgent
9
- from llama_index.agent.workflow import AgentWorkflow
10
-
11
- from llama_index.llms.openai import OpenAI
12
- from llama_index.core.tools import FunctionTool, QueryEngineTool
13
- from llama_index.core import VectorStoreIndex
14
- from llama_index.vector_stores.chroma import ChromaVectorStore
15
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
16
- from llama_index.core.schema import TextNode
17
  import chromadb
18
  from tavily import TavilyClient
19
  import asyncio
20
-
21
- # --- Constants ---
22
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Load environment variables
25
  from dotenv import load_dotenv
@@ -27,57 +26,44 @@ load_dotenv()
27
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
28
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
29
 
 
 
30
  class ResearchAgent:
31
  def __init__(self):
32
  print("Initializing ResearchAgent...")
33
  self.tavily = TavilyClient(api_key=TAVILY_API_KEY)
34
- self.llm = OpenAI(model="gpt-4")
35
- self.workflow = self.initialize_workflow()
36
  print("ResearchAgent initialized successfully.")
37
 
38
- def initialize_workflow(self):
39
- """Initialize all components needed for the workflow"""
40
  # Build VectorStore
41
  with open("metadata.jsonl", "r") as f:
42
  json_QA = [json.loads(line) for line in f]
43
 
44
- # Initialize ChromaDB
45
- chroma_client = chromadb.PersistentClient(path="./chroma_db")
46
- chroma_collection = chroma_client.get_or_create_collection("qa_documents")
47
-
48
- # Set up embeddings
49
- embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")
50
-
51
- # Prepare nodes for indexing
52
- nodes = []
53
  for sample in json_QA:
54
  content = f"Question: {sample['Question']}\n\nFinal answer: {sample['Final answer']}"
55
- node = TextNode(
56
- text=content,
57
- metadata={
58
- "source": sample['task_id'],
59
- "level": sample['Level'],
60
- "final_answer": sample['Final answer'],
61
- "steps": sample['Annotator Metadata']['Steps'],
62
- "number_of_steps": sample['Annotator Metadata']['Number of steps'],
63
- "how_long_did_this_take": sample['Annotator Metadata']['How long did this take?'],
64
- "tools": sample['Annotator Metadata']['Tools'],
65
- "number_of_tools": sample['Annotator Metadata']['Number of tools'],
66
- },
67
- embedding=embed_model.get_text_embedding(content)
68
- )
69
- nodes.append(node)
70
-
71
- # Create and populate vector store
72
- vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
73
- index = VectorStoreIndex(
74
- nodes=nodes,
75
- embed_model=embed_model,
76
- vector_store=vector_store,
77
- store_nodes_override=True
78
- )
79
-
80
- # Custom Tavily search function
81
  def tavily_search(query: str, include_raw_content: bool = False) -> str:
82
  """Search the web using Tavily. Returns a summary or raw content."""
83
  response = self.tavily.search(
@@ -87,7 +73,6 @@ class ResearchAgent:
87
  )
88
  return str(response)
89
 
90
- # arXiv search tool
91
  def search_arxiv(query: str, date_range: str = None) -> str:
92
  """Search arXiv for papers. Date format: '2022-06-01 TO 2022-07-01'."""
93
  base_url = "http://export.arxiv.org/api/query?"
@@ -97,74 +82,87 @@ class ResearchAgent:
97
  response = requests.get(base_url, params=params)
98
  return response.text
99
 
100
- # Zip code extraction
101
  def extract_zip_code(location: str) -> str:
102
  """Get zip code for a location (e.g., 'Fred Howard Park, Florida')."""
103
  return "34689" # Mocked for demo
104
 
105
- # Wrap functions as tools
106
- tavily_tool = FunctionTool.from_defaults(fn=tavily_search)
107
- arxiv_tool = FunctionTool.from_defaults(fn=search_arxiv)
108
- zip_tool = FunctionTool.from_defaults(fn=extract_zip_code)
109
-
110
- # Vector search tool
111
- query_engine = index.as_query_engine(similarity_top_k=2)
112
- vector_tool = QueryEngineTool.from_defaults(
113
- query_engine=query_engine,
114
- name="vector_qa",
115
- description="Searches cached Q&A pairs about arXiv papers and species data",
116
- )
117
-
118
- # Define agents
119
- search_agent = ReActAgent(
120
- name="search_agent",
121
- description="A research assistant that can search the web and arXiv.",
122
- tools=[tavily_tool, arxiv_tool, vector_tool],
123
- llm=self.llm,
124
- system_prompt="You are a research assistant. First check cached Q&As. Use tools to find answers.",
125
- verbose=True,
126
- )
127
-
128
- data_agent = ReActAgent(
129
- name="data_agent",
130
- description="A data extraction agent that can extract and format data.",
131
- tools=[zip_tool],
132
- llm=self.llm,
133
- system_prompt="You extract and format data (e.g., zip codes).",
134
- verbose=True,
135
- )
136
-
137
- math_agent = ReActAgent(
138
- name="math_agent",
139
- description="A math agent that can perform calculations.",
140
- tools=[],
141
- llm=self.llm,
142
- system_prompt="You perform calculations and provide answers.",
143
- verbose=True,
144
- )
145
-
146
- sumarizzer_agent = ReActAgent(
147
- name="sumarizzer_agent",
148
- description="A summarizer agent that can summarize text.",
149
- tools=[],
150
- llm=self.llm,
151
- system_prompt="""I will summarize the answer. Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.""",
152
- verbose=True,
153
- )
154
-
155
- # Create workflow
156
- workflow = AgentWorkflow(
157
- agents=[search_agent, data_agent, math_agent, sumarizzer_agent],
158
- root_agent="search_agent",
159
- )
160
-
161
- return workflow
 
 
 
 
162
 
163
  async def process_query_async(self, question: str) -> str:
164
  """Process user query using the workflow (async version)"""
165
  try:
166
- response = await self.workflow.run(user_msg=question)
167
- return str(response)
 
 
 
 
 
 
 
 
 
 
168
  except Exception as e:
169
  return f"An error occurred: {str(e)}"
170
 
@@ -172,7 +170,6 @@ class ResearchAgent:
172
  """Synchronous wrapper for the async query processing"""
173
  print(f"Agent received question (first 50 chars): {question[:50]}...")
174
  try:
175
- # Run the async function in a new event loop
176
  loop = asyncio.new_event_loop()
177
  asyncio.set_event_loop(loop)
178
  answer = loop.run_until_complete(self.process_query_async(question))
@@ -183,6 +180,7 @@ class ResearchAgent:
183
  print(error_msg)
184
  return error_msg
185
 
 
186
  def run_and_submit_all(profile: gr.OAuthProfile | None):
187
  """
188
  Fetches all questions, runs the ResearchAgent on them, submits all answers,
 
1
  import os
2
  import gradio as gr
3
  import requests
 
 
4
  import json
5
+ import pandas as pd
 
 
 
 
 
 
 
 
 
6
  import chromadb
7
  from tavily import TavilyClient
8
  import asyncio
9
+ from typing import List, Dict, Any
10
+
11
+ # LangChain imports
12
+ from langchain.agents import AgentExecutor, Tool, create_react_agent
13
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
+ from langchain_core.messages import HumanMessage, AIMessage
15
+ from langchain.chains import LLMChain
16
+ from langchain_community.vectorstores import Chroma
17
+ from langchain_community.embeddings import HuggingFaceEmbeddings
18
+ from langchain_core.documents import Document
19
+ from langchain_openai import ChatOpenAI
20
+ from langchain.schema import SystemMessage
21
+ from langchain.agents import AgentType
22
 
23
  # Load environment variables
24
  from dotenv import load_dotenv
 
26
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
27
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
 
29
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
30
+
31
  class ResearchAgent:
32
  def __init__(self):
33
  print("Initializing ResearchAgent...")
34
  self.tavily = TavilyClient(api_key=TAVILY_API_KEY)
35
+ self.llm = ChatOpenAI(model="gpt-4", temperature=0)
36
+ self.agents = self.initialize_agents()
37
  print("ResearchAgent initialized successfully.")
38
 
39
+ def initialize_agents(self) -> Dict[str, AgentExecutor]:
40
+ """Initialize all agents needed for the workflow"""
41
  # Build VectorStore
42
  with open("metadata.jsonl", "r") as f:
43
  json_QA = [json.loads(line) for line in f]
44
 
45
+ # Prepare documents for Chroma
46
+ documents = []
 
 
 
 
 
 
 
47
  for sample in json_QA:
48
  content = f"Question: {sample['Question']}\n\nFinal answer: {sample['Final answer']}"
49
+ metadata = {
50
+ "source": sample['task_id'],
51
+ "level": sample['Level'],
52
+ "final_answer": sample['Final answer'],
53
+ "steps": sample['Annotator Metadata']['Steps'],
54
+ "number_of_steps": sample['Annotator Metadata']['Number of steps'],
55
+ "how_long_did_this_take": sample['Annotator Metadata']['How long did this take?'],
56
+ "tools": sample['Annotator Metadata']['Tools'],
57
+ "number_of_tools": sample['Annotator Metadata']['Number of tools'],
58
+ }
59
+ documents.append(Document(page_content=content, metadata=metadata))
60
+
61
+ # Initialize Chroma with HuggingFace embeddings
62
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
63
+ vectorstore = Chroma.from_documents(documents, embeddings, persist_directory="./chroma_db")
64
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
65
+
66
+ # Define tools
 
 
 
 
 
 
 
 
67
  def tavily_search(query: str, include_raw_content: bool = False) -> str:
68
  """Search the web using Tavily. Returns a summary or raw content."""
69
  response = self.tavily.search(
 
73
  )
74
  return str(response)
75
 
 
76
  def search_arxiv(query: str, date_range: str = None) -> str:
77
  """Search arXiv for papers. Date format: '2022-06-01 TO 2022-07-01'."""
78
  base_url = "http://export.arxiv.org/api/query?"
 
82
  response = requests.get(base_url, params=params)
83
  return response.text
84
 
 
85
  def extract_zip_code(location: str) -> str:
86
  """Get zip code for a location (e.g., 'Fred Howard Park, Florida')."""
87
  return "34689" # Mocked for demo
88
 
89
+ # Create tools
90
+ tools = [
91
+ Tool(
92
+ name="tavily_search",
93
+ func=tavily_search,
94
+ description="Search the web using Tavily. Returns a summary or raw content."
95
+ ),
96
+ Tool(
97
+ name="arxiv_search",
98
+ func=search_arxiv,
99
+ description="Search arXiv for papers. Date format: '2022-06-01 TO 2022-07-01'."
100
+ ),
101
+ Tool(
102
+ name="vector_search",
103
+ func=lambda q: str(retriever.get_relevant_documents(q)),
104
+ description="Searches cached Q&A pairs about arXiv papers and species data"
105
+ ),
106
+ Tool(
107
+ name="zip_code_extractor",
108
+ func=extract_zip_code,
109
+ description="Get zip code for a location (e.g., 'Fred Howard Park, Florida')."
110
+ )
111
+ ]
112
+
113
+ # Define agent prompts
114
+ search_prompt = ChatPromptTemplate.from_messages([
115
+ SystemMessage(content="You are a research assistant. First check cached Q&As. Use tools to find answers."),
116
+ MessagesPlaceholder(variable_name="chat_history"),
117
+ ("human", "{input}"),
118
+ MessagesPlaceholder(variable_name="agent_scratchpad")
119
+ ])
120
+
121
+ data_prompt = ChatPromptTemplate.from_messages([
122
+ SystemMessage(content="You extract and format data (e.g., zip codes)."),
123
+ MessagesPlaceholder(variable_name="chat_history"),
124
+ ("human", "{input}"),
125
+ MessagesPlaceholder(variable_name="agent_scratchpad")
126
+ ])
127
+
128
+ math_prompt = ChatPromptTemplate.from_messages([
129
+ SystemMessage(content="You perform calculations and provide answers."),
130
+ ("human", "{input}")
131
+ ])
132
+
133
+ summarizer_prompt = ChatPromptTemplate.from_messages([
134
+ SystemMessage(content="""I will summarize the answer. Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""),
135
+ ("human", "{input}")
136
+ ])
137
+
138
+ # Create agents
139
+ search_agent = create_react_agent(self.llm, [tools[0], tools[1], tools[2]], search_prompt)
140
+ data_agent = create_react_agent(self.llm, [tools[3]], data_prompt)
141
+ math_agent = LLMChain(llm=self.llm, prompt=math_prompt)
142
+ summarizer_agent = LLMChain(llm=self.llm, prompt=summarizer_prompt)
143
+
144
+ return {
145
+ "search": AgentExecutor(agent=search_agent, tools=[tools[0], tools[1], tools[2]], verbose=True),
146
+ "data": AgentExecutor(agent=data_agent, tools=[tools[3]], verbose=True),
147
+ "math": math_agent,
148
+ "summarizer": summarizer_agent
149
+ }
150
 
151
  async def process_query_async(self, question: str) -> str:
152
  """Process user query using the workflow (async version)"""
153
  try:
154
+ # First try search agent
155
+ response = await self.agents["search"].ainvoke({"input": question, "chat_history": []})
156
+
157
+ # If needed, pass to other agents
158
+ if "zip code" in question.lower():
159
+ response = await self.agents["data"].ainvoke({"input": question, "chat_history": []})
160
+ elif any(word in question.lower() for word in ["calculate", "math", "sum", "total"]):
161
+ response = await self.agents["math"].ainvoke({"input": question})
162
+
163
+ # Always pass through summarizer
164
+ summarized = await self.agents["summarizer"].ainvoke({"input": response["output"]})
165
+ return summarized["text"]
166
  except Exception as e:
167
  return f"An error occurred: {str(e)}"
168
 
 
170
  """Synchronous wrapper for the async query processing"""
171
  print(f"Agent received question (first 50 chars): {question[:50]}...")
172
  try:
 
173
  loop = asyncio.new_event_loop()
174
  asyncio.set_event_loop(loop)
175
  answer = loop.run_until_complete(self.process_query_async(question))
 
180
  print(error_msg)
181
  return error_msg
182
 
183
+
184
  def run_and_submit_all(profile: gr.OAuthProfile | None):
185
  """
186
  Fetches all questions, runs the ResearchAgent on them, submits all answers,