Soumik Bose committed on
Commit
548fe9b
·
1 Parent(s): 046b2da
Files changed (1) hide show
  1. gemini_langchain_agent.py +22 -27
gemini_langchain_agent.py CHANGED
@@ -3,8 +3,7 @@ import uuid
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  import pandas as pd
5
  from langchain_core.prompts import ChatPromptTemplate
6
- # Removed the import for PythonAstREPLTool as it causes a duplicate tool declaration
7
- # from langchain_experimental.tools import PythonAstREPLTool
8
  from langchain_experimental.agents import create_pandas_dataframe_agent
9
  from dotenv import load_dotenv
10
  import numpy as np
@@ -27,18 +26,15 @@ llm_instances = [
27
  ]
28
  current_instance_index = 0 # Track current instance being used
29
 
30
- # Modified create_agent function: Removed the 'tools' parameter and 'extra_tools' argument
31
- def create_agent(llm, data):
32
- """Create agent with tool names. create_pandas_dataframe_agent typically includes its own Python REPL tool."""
33
  return create_pandas_dataframe_agent(
34
  llm,
35
  data,
36
  agent_type="tool-calling",
37
  verbose=True,
38
  allow_dangerous_code=True,
39
- # The PythonAstREPLTool was causing a duplicate declaration.
40
- # create_pandas_dataframe_agent is expected to provide its own REPL for DataFrame interaction.
41
- # extra_tools=tools,
42
  return_intermediate_steps=True
43
  )
44
 
@@ -49,7 +45,7 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
49
  2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
50
  3. **Communication:** Provide concise, professional, and well-structured responses.
51
  4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
52
- 5. Always use pd.read_csv("{csv_url}") to read the CSV file.
53
 
54
  **Query:** {question}
55
  """
@@ -84,7 +80,7 @@ def _prompt_generator(question: str, chart_required: bool, csv_url: str):
84
  - Use THE SAME unique_id throughout entire process
85
  - NEVER generate new UUIDs after initial creation
86
  - Return EXACT filepath string of the final saved chart
87
- - Always use pd.read_csv("{csv_url}") to read the CSV file
88
  """
89
 
90
  if chart_required:
@@ -102,23 +98,21 @@ def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bo
102
  llm = llm_instances[current_instance_index]
103
  print(f"Using LLM instance index {current_instance_index}")
104
 
105
- # Removed explicit creation of PythonAstREPLTool
106
- # The create_pandas_dataframe_agent is expected to internally handle Python execution for DataFrame operations.
107
- # tool = PythonAstREPLTool(
108
- # locals={
109
- # "df": data,
110
- # "pd": pd,
111
- # "np": np,
112
- # "plt": plt,
113
- # "sns": sns,
114
- # "matplotlib": matplotlib,
115
- # "uuid": uuid,
116
- # "dt": dt
117
- # },
118
- # )
119
 
120
- # Modified create_agent call: Removed the 'tools' argument
121
- agent = create_agent(llm, data)
122
  prompt = _prompt_generator(question, chart_required, csv_url)
123
  result = agent.invoke({"input": prompt})
124
  output = result.get("output")
@@ -133,4 +127,5 @@ def langchain_gemini_csv_handler(csv_url: str, question: str, chart_required: bo
133
  current_instance_index += 1
134
 
135
  print("All LLM instances have been exhausted.")
136
- return None
 
 
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  import pandas as pd
5
  from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_experimental.tools import PythonAstREPLTool
 
7
  from langchain_experimental.agents import create_pandas_dataframe_agent
8
  from dotenv import load_dotenv
9
  import numpy as np
 
26
  ]
27
  current_instance_index = 0 # Track current instance being used
28
 
29
+ def create_agent(llm, data, tools):
30
+ """Create agent with tool names"""
 
31
  return create_pandas_dataframe_agent(
32
  llm,
33
  data,
34
  agent_type="tool-calling",
35
  verbose=True,
36
  allow_dangerous_code=True,
37
+ extra_tools=tools,
 
 
38
  return_intermediate_steps=True
39
  )
40
 
 
45
  2. **Data Integrity:** Ensure proper handling of null values to maintain accuracy and reliability.
46
  3. **Communication:** Provide concise, professional, and well-structured responses.
47
  4. Avoid including any internal processing details or references to the methods used to generate your response (ex: based on the tool call, using the function -> These types of phrases.)
48
+ 5. Always use pd.read_csv({csv_url}) to read the CSV file.
49
 
50
  **Query:** {question}
51
  """
 
80
  - Use THE SAME unique_id throughout entire process
81
  - NEVER generate new UUIDs after initial creation
82
  - Return EXACT filepath string of the final saved chart
83
+ - Always use pd.read_csv({csv_url}) to read the CSV file
84
  """
85
 
86
  if chart_required:
 
98
  llm = llm_instances[current_instance_index]
99
  print(f"Using LLM instance index {current_instance_index}")
100
 
101
+ # Create tool with validated name
102
+ tool = PythonAstREPLTool(
103
+ locals={
104
+ "df": data,
105
+ "pd": pd,
106
+ "np": np,
107
+ "plt": plt,
108
+ "sns": sns,
109
+ "matplotlib": matplotlib,
110
+ "uuid": uuid,
111
+ "dt": dt
112
+ },
113
+ )
 
114
 
115
+ agent = create_agent(llm, data, [tool])
 
116
  prompt = _prompt_generator(question, chart_required, csv_url)
117
  result = agent.invoke({"input": prompt})
118
  output = result.get("output")
 
127
  current_instance_index += 1
128
 
129
  print("All LLM instances have been exhausted.")
130
+ return None
131
+