Scott Cogan commited on
Commit
d293a8a
·
1 Parent(s): 7a788c6

revert to working app

Browse files
Files changed (1) hide show
  1. app.py +277 -4
app.py CHANGED
@@ -1,6 +1,279 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return f"Hello, {name}!"
5
 
6
- app = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, Tool, DuckDuckGoSearchTool, load_tool, tool
2
+ from models import OpenAIModel # Import our local model
3
+ import datetime
4
+ import requests
5
+ import pytz
6
+ import yaml
7
+ from tools.final_answer import FinalAnswerTool
8
+ import re
9
+ import os
10
+ import logging
11
+ from jinja2 import Template, StrictUndefined
12
 
13
+ from Gradio_UI import GradioUI
 
14
 
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.DEBUG)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Below is an example of a tool that does nothing. Amaze us with your creativity !
20
+ @tool
21
+ def calculate_min_price(prices: list[float])-> str: #it's import to specify the return type
22
+ """A tool that calculates the min price from list of product prices
23
+ Args:
24
+ prices: list of product prices of
25
+ """
26
+ min_price =min(prices)
27
+ return f"The minimum price is {min_price}"
28
+
29
+ @tool
30
+ def extract_price_from_snippet(snippet: str) -> list[str]:
31
+ """
32
+ A simple function to extract prices from a text snippet using regex.
33
+ You can enhance this function for more complex price extraction.
34
+ Args:
35
+ snippet: text of all prices
36
+ """
37
+ # A basic regular expression to detect common price formats like $29.99, 29.99 USD, etc.
38
+ price_pattern = r'\$\d+(?:,\d{3})*(?:\.\d{2})?|\d+(?:,\d{3})*(?:\.\d{2})?\s*(USD|EUR|GBP|INR|AUD|CAD)?'
39
+ matches = re.findall(price_pattern, snippet)
40
+ matches = [str(x) for x in matches]
41
+ return matches
42
+
43
+
44
+ @tool
45
+ def get_current_time_in_timezone(timezone: str) -> str:
46
+ """A tool that fetches the current local time in a specified timezone.
47
+ Args:
48
+ timezone: A string representing a valid timezone (e.g., 'America/New_York').
49
+ """
50
+ try:
51
+ # Create timezone object
52
+ tz = pytz.timezone(timezone)
53
+ # Get current time in that timezone
54
+ local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
55
+ return f"The current local time in {timezone} is: {local_time}"
56
+ except Exception as e:
57
+ return f"Error fetching time for timezone '{timezone}': {str(e)}"
58
+
59
+
60
+ final_answer = FinalAnswerTool()
61
+
62
+ # Load configuration from agent.json
63
+ with open("agent.json", 'r') as f:
64
+ config = yaml.safe_load(f)
65
+
66
+ # Initialize OpenAI model using the configuration
67
+ model = OpenAIModel(
68
+ model=config["model"]["data"]["model_id"],
69
+ max_tokens=config["model"]["data"]["max_tokens"],
70
+ temperature=config["model"]["data"]["temperature"],
71
+ api_key=os.getenv("OPENAI_API_KEY")
72
+ )
73
+
74
+ # Import tool from Hub
75
+ image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
76
+
77
+ # Load and validate templates from prompts.yaml
78
+ with open("prompts.yaml", 'r') as stream:
79
+ yaml_content = yaml.safe_load(stream)
80
+ if not isinstance(yaml_content, dict):
81
+ raise ValueError("YAML content must be a dictionary")
82
+ if 'prompt_templates' not in yaml_content:
83
+ raise ValueError("YAML must contain 'prompt_templates' key")
84
+ prompt_templates = yaml_content['prompt_templates']
85
+ if not isinstance(prompt_templates, dict):
86
+ raise ValueError("prompt_templates must be a dictionary")
87
+
88
+ logger.debug("Starting template validation...")
89
+ try:
90
+ # Create test tools for validation
91
+ class TestTool(Tool):
92
+ name = "test_tool"
93
+ description = "A test tool"
94
+ inputs = {'input': {'type': 'string', 'description': 'Test input'}}
95
+ output_type = "string"
96
+
97
+ def forward(self, input: str) -> str:
98
+ return "test output"
99
+
100
+ test_tools = [TestTool()] # Create a list of tools
101
+
102
+ # Validate templates
103
+ for key, value in prompt_templates.items():
104
+ if isinstance(value, dict):
105
+ for subkey, subvalue in value.items():
106
+ if isinstance(subvalue, str):
107
+ logger.debug(f"Validating template: {key}.{subkey}")
108
+ logger.debug(f"Template content: {subvalue[:200]}...")
109
+
110
+ # Validate template
111
+ try:
112
+ # Create a template with at least one template node
113
+ template_str = subvalue
114
+ if '{{' not in template_str and '{%' not in template_str:
115
+ template_str = "{{ task }}\n" + template_str
116
+
117
+ # Create a template with the task variable
118
+ template = Template(template_str, undefined=StrictUndefined)
119
+ rendered = template.render(
120
+ tools=test_tools, # Pass list of tools
121
+ task="test",
122
+ name="test",
123
+ final_answer="test",
124
+ remaining_steps=1,
125
+ answer_facts="test"
126
+ )
127
+ logger.debug(f"Template render test successful for {key}.{subkey}")
128
+ except Exception as e:
129
+ logger.error(f"Template render test failed for {key}.{subkey}: {str(e)}")
130
+ raise
131
+ elif isinstance(value, str):
132
+ logger.debug(f"Validating template: {key}")
133
+ logger.debug(f"Template content: {value[:200]}...")
134
+
135
+ # Validate template
136
+ try:
137
+ # Create a template with at least one template node
138
+ template_str = value
139
+ if '{{' not in template_str and '{%' not in template_str:
140
+ template_str = "{{ task }}\n" + template_str
141
+
142
+ # Create a template with the task variable
143
+ template = Template(template_str, undefined=StrictUndefined)
144
+ rendered = template.render(
145
+ tools=test_tools, # Pass list of tools
146
+ task="test",
147
+ name="test",
148
+ final_answer="test",
149
+ remaining_steps=1,
150
+ answer_facts="test"
151
+ )
152
+ logger.debug(f"Template render test successful for {key}")
153
+ except Exception as e:
154
+ logger.error(f"Template render test failed for {key}: {str(e)}")
155
+ raise
156
+
157
+ logger.debug("Template validation completed successfully")
158
+
159
+ except Exception as e:
160
+ logger.error(f"Error during template validation: {str(e)}")
161
+ raise
162
+
163
+ # Create the agent with the templates
164
+ class CalculateMinPriceTool(Tool):
165
+ name = "calculate_min_price"
166
+ description = "Calculate the minimum price from a list of prices"
167
+ inputs = {'prices': {'type': 'array', 'description': 'List of product prices (numbers)'}}
168
+ output_type = "string"
169
+
170
+ def forward(self, prices: list[float]) -> str:
171
+ min_price = min(prices)
172
+ return f"The minimum price is {min_price}"
173
+
174
+ class ExtractPriceFromSnippetTool(Tool):
175
+ name = "extract_price_from_snippet"
176
+ description = "Extract prices from a text snippet"
177
+ inputs = {'snippet': {'type': 'string', 'description': 'Text containing prices'}}
178
+ output_type = "array"
179
+
180
+ def forward(self, snippet: str) -> list[str]:
181
+ price_pattern = r'\$\d+(?:,\d{3})*(?:\.\d{2})?|\d+(?:,\d{3})*(?:\.\d{2})?\s*(USD|EUR|GBP|INR|AUD|CAD)?'
182
+ matches = re.findall(price_pattern, snippet)
183
+ matches = [str(x) for x in matches]
184
+ return matches
185
+
186
+ class GetCurrentTimeInTimezoneTool(Tool):
187
+ name = "get_current_time_in_timezone"
188
+ description = "Get the current time in a specified timezone"
189
+ inputs = {'timezone': {'type': 'string', 'description': 'A valid timezone (e.g., America/New_York)'}}
190
+ output_type = "string"
191
+
192
+ def forward(self, timezone: str) -> str:
193
+ try:
194
+ tz = pytz.timezone(timezone)
195
+ local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
196
+ return f"The current local time in {timezone} is: {local_time}"
197
+ except Exception as e:
198
+ return f"Error fetching time for timezone '{timezone}': {str(e)}"
199
+
200
+ # Create the agent with the templates
201
+ tools = [
202
+ final_answer, # This is already a Tool instance
203
+ DuckDuckGoSearchTool(), # This is already a Tool instance
204
+ CalculateMinPriceTool(), # This is a Tool subclass
205
+ ExtractPriceFromSnippetTool(), # This is a Tool subclass
206
+ GetCurrentTimeInTimezoneTool() # This is a Tool subclass
207
+ ]
208
+
209
+ # Debug prints to inspect tools
210
+ logger.debug("Inspecting tools list:")
211
+ for i, tool in enumerate(tools):
212
+ logger.debug(f"Tool {i}: {tool}")
213
+ logger.debug(f" Type: {type(tool)}")
214
+ logger.debug(f" Is Tool instance: {isinstance(tool, Tool)}")
215
+ logger.debug(f" Has name: {hasattr(tool, 'name')}")
216
+ logger.debug(f" Has description: {hasattr(tool, 'description')}")
217
+ if hasattr(tool, 'name'):
218
+ logger.debug(f" Name: {tool.name}")
219
+ if hasattr(tool, 'description'):
220
+ logger.debug(f" Description: {tool.description}")
221
+
222
+ # Verify all tools are proper Tool instances and have required attributes
223
+ for tool in tools:
224
+ if not isinstance(tool, Tool):
225
+ raise TypeError(f"Tool {tool} is not an instance of Tool or its subclass")
226
+ if not hasattr(tool, 'name') or not hasattr(tool, 'description'):
227
+ raise AttributeError(f"Tool {tool} is missing required attributes (name or description)")
228
+
229
+ # Create a dictionary of tools for template rendering
230
+ tools_dict = {tool.name: tool for tool in tools}
231
+ logger.debug("Created tools dictionary:")
232
+ for name, tool in tools_dict.items():
233
+ logger.debug(f" {name}: {tool}")
234
+
235
+ print("[DEBUG] Starting app.py initialization...")
236
+
237
+ # Create a custom CodeAgent class that handles both list and dictionary requirements
238
+ class CustomCodeAgent(CodeAgent):
239
+ def __init__(self, *args, **kwargs):
240
+ print("[DEBUG] Initializing CustomCodeAgent...")
241
+ # Store the tools dictionary for template rendering
242
+ self.tools_dict = kwargs.pop('tools_dict', {})
243
+ # Initialize parent class first
244
+ super().__init__(*args, **kwargs)
245
+
246
+ def initialize_system_prompt(self):
247
+ print("[DEBUG] Initializing system prompt...")
248
+ # Override to use tools_dict for template rendering
249
+ template = self.prompt_templates["system_prompt"]
250
+ # Convert tools_dict to list for template rendering
251
+ tools_list = list(self.tools_dict.values())
252
+ # Use Jinja2 template rendering
253
+ return Template(template, undefined=StrictUndefined).render(
254
+ tools=tools_list, # Pass tools as a list
255
+ task=getattr(self, 'task', ''),
256
+ managed_agents=getattr(self, 'managed_agents', []),
257
+ authorized_imports=getattr(self, 'authorized_imports', [])
258
+ )
259
+
260
+ print("[DEBUG] Creating agent instance...")
261
+ agent = CustomCodeAgent(
262
+ model=model,
263
+ tools=tools, # Pass tools as a list for _setup_tools
264
+ tools_dict=tools_dict, # Pass tools as a dictionary for template rendering
265
+ max_steps=15,
266
+ verbosity_level=2,
267
+ grammar=None,
268
+ planning_interval=1,
269
+ name="question_answering_agent",
270
+ description="An agent specialized in answering various types of questions using available tools. The agent must use the final_answer tool to submit its answer.",
271
+ prompt_templates=prompt_templates
272
+ )
273
+
274
+ print("[DEBUG] Building Gradio demo...")
275
+ demo = GradioUI(agent).build_blocks()
276
+ print("[DEBUG] Gradio demo built and exposed as 'demo'.")
277
+ print("[DEBUG] About to assign app = demo")
278
+ app = demo
279
+ print("[DEBUG] app assigned successfully")