prige commited on
Commit
2fd68f6
·
verified ·
1 Parent(s): c0163a7

Upload tool

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. tool.py +80 -5
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- smolagents
2
  requests
 
 
 
1
  requests
2
+ smolagents
tool.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
 
 
3
  import requests
 
4
 
5
  class UserLocationTool(Tool):
6
  name = "get_user_location"
@@ -9,16 +12,88 @@ class UserLocationTool(Tool):
9
  output_type = "string"
10
 
11
  def __init__(self):
 
12
  super().__init__()
13
  self.dependencies = ["requests"]
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def forward(self) -> str:
16
- # Import requests inside the function as per requirements
17
  import requests
 
 
 
 
 
18
 
19
  try:
20
- response = requests.get("https://ipinfo.io/")
21
- data = response.json()
22
- return data["city"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  except Exception as e:
24
- return f"Error getting user location: {str(e)}. It looks like https://ipinfo.io/ did not return a valid response."
 
 
 
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
+ import json
4
+ import typing
5
  import requests
6
+ import logging
7
 
8
  class UserLocationTool(Tool):
9
  name = "get_user_location"
 
12
  output_type = "string"
13
 
14
  def __init__(self):
15
+ # Calls the constructor of the parent class `Tool`
16
  super().__init__()
17
  self.dependencies = ["requests"]
18
 
19
+ def _validate_location_data(self, data: dict) -> str | None:
20
+ """Validates the location data and returns error message if invalid."""
21
+ # Import inside method as per SmolAgents requirements
22
+ from typing import Optional
23
+
24
+ required_fields = ['city', 'region', 'country']
25
+
26
+ if not isinstance(data, dict):
27
+ return f"Expected dictionary data, got {type(data)}"
28
+
29
+ missing_fields = [field for field in required_fields if field not in data]
30
+ if missing_fields:
31
+ return f"Missing required location fields: {', '.join(missing_fields)}"
32
+
33
+ empty_fields = [field for field in required_fields if not data.get(field)]
34
+ if empty_fields:
35
+ return f"Empty values for fields: {', '.join(empty_fields)}"
36
+
37
+ return None
38
+
39
  def forward(self) -> str:
40
+ # Import inside method as per SmolAgents requirements
41
  import requests
42
+ import json
43
+ import logging
44
+
45
+ logger = logging.getLogger(__name__)
46
+ logger.setLevel(logging.INFO)
47
 
48
  try:
49
+ logger.info("Initiating request to ipinfo.io")
50
+ response = requests.get("https://ipinfo.io/", timeout=5)
51
+
52
+ # Handle rate limiting
53
+ if response.status_code == 429:
54
+ error_msg = "Rate limit exceeded for IP geolocation service"
55
+ logger.error(error_msg)
56
+ return f"Error getting user location: {error_msg}. Please try again later."
57
+
58
+ # Handle other HTTP errors
59
+ if response.status_code != 200:
60
+ error_msg = f"HTTP {response.status_code}: {response.reason}"
61
+ logger.error(f"ipinfo.io request failed: {error_msg}")
62
+ return f"Error getting user location: Server returned {error_msg}"
63
+
64
+ try:
65
+ data = response.json()
66
+ except json.JSONDecodeError as e:
67
+ error_msg = f"Invalid JSON response: {str(e)}"
68
+ logger.error(error_msg)
69
+ return f"Error getting user location: {error_msg}"
70
+
71
+ # Validate the response data
72
+ validation_error = self._validate_location_data(data)
73
+ if validation_error:
74
+ logger.error(f"Data validation failed: {validation_error}")
75
+ return f"Error getting user location: {validation_error}"
76
+
77
+ location = f"According to your IP address, the user's location is {data['city']}, {data['region']}, {data['country']}."
78
+ logger.info(f"Successfully retrieved location: {location}")
79
+ return location
80
+
81
+ except requests.exceptions.Timeout:
82
+ error_msg = "Request timed out after 5 seconds"
83
+ logger.error(error_msg)
84
+ return f"Error getting user location: {error_msg}. Please try again."
85
+
86
+ except requests.exceptions.ConnectionError as e:
87
+ error_msg = f"Connection failed: {str(e)}"
88
+ logger.error(error_msg)
89
+ return f"Error getting user location: Unable to connect to location service. Please check your internet connection."
90
+
91
+ except requests.exceptions.RequestException as e:
92
+ error_msg = f"Network error: {str(e)}"
93
+ logger.error(error_msg)
94
+ return f"Error getting user location: {error_msg}"
95
+
96
  except Exception as e:
97
+ error_msg = f"Unexpected error: {str(e)}"
98
+ logger.error(error_msg)
99
+ return f"Error getting user location: An unexpected error occurred. Please try again later."