prige commited on
Commit
520b75d
·
verified ·
1 Parent(s): 3fbf461

Upload tool

Browse files
Files changed (1) hide show
  1. tool.py +27 -5
tool.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
 
3
 
4
  class UserInputTool(Tool):
5
  name = "user_input"
@@ -7,9 +8,30 @@ class UserInputTool(Tool):
7
  inputs = {'question': {'type': 'string', 'description': 'The question to ask the user'}}
8
  output_type = "string"
9
 
10
- def forward(self, question):
11
- user_input = input(f"{question} => Type your answer here:")
12
- return user_input
 
 
 
13
 
14
- def __init__(self, *args, **kwargs):
15
- self.is_initialized = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
+ import logging
4
 
5
  class UserInputTool(Tool):
6
  name = "user_input"
 
8
  inputs = {'question': {'type': 'string', 'description': 'The question to ask the user'}}
9
  output_type = "string"
10
 
11
+ def __init__(self):
12
+ super().__init__()
13
+ import logging
14
+ self.logger = logging.getLogger(__name__)
15
+ self.logger.setLevel(logging.INFO)
16
+ self.user_input = "" # Initialize user_input as an instance variable
17
 
18
+ def _validate_question(self, question) -> tuple[bool, str]:
19
+ # Helper method to validate the question
20
+ if not isinstance(question, str):
21
+ return False, f"Question must be a string, got {type(question)}"
22
+ if not question.strip():
23
+ return False, "Question cannot be empty"
24
+ return True, question
25
+
26
+ async def forward(self, question: str) -> str:
27
+ # Validate the question first
28
+ success, response = self._validate_question(question)
29
+ if not success:
30
+ self.logger.error(response)
31
+ return f"Error: {response}"
32
+
33
+ # Ask the validated question
34
+ self.logger.info(f"Asking user: {question}")
35
+ self.user_input = input(f"{question} => Type your answer here:") # Use instance variable
36
+ self.logger.info(f"Received user input: {self.user_input}")
37
+ return self.user_input