prige commited on
Commit
371df9f
·
verified ·
1 Parent(s): fad6849

Upload tool

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. tool.py +54 -76
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  smolagents
2
- requests
 
1
  smolagents
2
+ aiohttp
tool.py CHANGED
@@ -1,8 +1,9 @@
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
- import json
4
- import requests
5
  import logging
 
 
6
 
7
  class UserLocationTool(Tool):
8
  name = "get_user_location"
@@ -10,88 +11,65 @@ class UserLocationTool(Tool):
10
  inputs = {}
11
  output_type = "string"
12
 
13
- def __init__(self):
14
- # Calls the constructor of the parent class `Tool`
15
  super().__init__()
16
- self.dependencies = ["requests"]
17
-
18
- def _validate_location_data(self, data: dict) -> str | None:
19
- """Validates the location data and returns error message if invalid."""
20
-
21
- required_fields = ['city', 'region', 'country']
22
-
23
- if not isinstance(data, dict):
24
- return f"Expected dictionary data, got {type(data)}"
25
-
26
- missing_fields = [field for field in required_fields if field not in data]
27
- if missing_fields:
28
- return f"Missing required location fields: {', '.join(missing_fields)}"
29
-
30
- empty_fields = [field for field in required_fields if not data.get(field)]
31
- if empty_fields:
32
- return f"Empty values for fields: {', '.join(empty_fields)}"
33
-
34
- return None
35
-
36
- def forward(self) -> str:
37
- # Import inside method as per SmolAgents requirements
38
- import requests
39
- import json
40
  import logging
 
 
41
 
42
- logger = logging.getLogger(__name__)
43
- logger.setLevel(logging.INFO)
44
-
45
- try:
46
- logger.info("Initiating request to ipinfo.io")
47
- response = requests.get("https://ipinfo.io/", timeout=5)
48
-
49
- # Handle rate limiting
50
- if response.status_code == 429:
51
- error_msg = "Rate limit exceeded for IP geolocation service"
52
- logger.error(error_msg)
53
- return f"Error getting user location: {error_msg}. Please try again later."
54
-
55
- # Handle other HTTP errors
56
- if response.status_code != 200:
57
- error_msg = f"HTTP {response.status_code}: {response.reason}"
58
- logger.error(f"ipinfo.io request failed: {error_msg}")
59
- return f"Error getting user location: Server returned {error_msg}"
60
 
61
- # check if the response is valid JSON
62
- try:
63
- data = response.json()
64
- except json.JSONDecodeError as e:
65
- error_msg = f"Invalid JSON response: {str(e)}"
66
- logger.error(error_msg)
67
- return f"Error getting user location: {error_msg}"
68
 
69
- # check if the data contains: city, region, and country
70
- validation_error = self._validate_location_data(data)
71
- if validation_error:
72
- logger.error(f"Data validation failed: {validation_error}")
73
- return f"Error getting user location: {validation_error}"
74
 
75
- location = f"According to your IP address, the user's location is {data['city']}, {data['region']}, {data['country']}."
76
- logger.info(f"Successfully retrieved location: {location}")
77
- return location
78
 
79
- except requests.exceptions.Timeout:
80
- error_msg = "Request timed out after 5 seconds"
81
- logger.error(error_msg)
82
- return f"Error getting user location: {error_msg}. Please try again."
83
 
84
- except requests.exceptions.ConnectionError as e:
85
- error_msg = f"Connection failed: {str(e)}"
86
- logger.error(error_msg)
87
- return f"Error getting user location: Unable to connect to location service. Please check your internet connection."
88
 
89
- except requests.exceptions.RequestException as e:
90
- error_msg = f"Network error: {str(e)}"
91
- logger.error(error_msg)
92
- return f"Error getting user location: {error_msg}"
93
 
 
 
 
 
 
 
 
 
94
  except Exception as e:
95
- error_msg = f"Unexpected error: {str(e)}"
96
- logger.error(error_msg)
97
- return f"Error getting user location: An unexpected error occurred. Please try again later."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
+ import asyncio
 
4
  import logging
5
+ import json
6
+ import aiohttp
7
 
8
  class UserLocationTool(Tool):
9
  name = "get_user_location"
 
11
  inputs = {}
12
  output_type = "string"
13
 
14
+ def __init__(self): # Initialize the logger
 
15
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import logging
17
+ self.logger = logging.getLogger(__name__)
18
+ self.logger.setLevel(logging.INFO)
19
 
20
+ def _validate_location_data(self, response: Any) -> tuple[bool, str]:
21
+ # Validates the API response, JSON parsing, and location data. Returns (is_valid, error_message or location_string)
22
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ if response.status != 200: # Check HTTP status
25
+ if response.status == 429:
26
+ return False, "Rate limit exceeded for IP geolocation service. Please try again later."
27
+ return False, f"Server returned HTTP {response.status}"
 
 
 
28
 
29
+ try: # Parse and validate JSON
30
+ response_data = response.json()
31
+ except json.JSONDecodeError:
32
+ return False, "Invalid response from server"
 
33
 
34
+ if not isinstance(response_data, dict): # Validate response data type
35
+ return False, f"Expected dictionary data, got {type(response_data)}"
 
36
 
37
+ required_fields = ['city', 'region', 'country'] # Validate required fields
38
+ missing_fields = [field for field in required_fields if not response_data.get(field)]
39
+ if missing_fields:
40
+ return False, f"Missing or empty location fields: {', '.join(missing_fields)}"
41
 
42
+ location = f"According to your IP address, the user's location is {response_data['city']}, {response_data['region']}, {response_data['country']}."
43
+ return True, location
 
 
44
 
45
+ async def _call_api(self) -> tuple[bool, Any]:
46
+ # Makes the API call to ipinfo.io. Returns (success, response_object_or_error_message)
47
+ import aiohttp
48
+ import asyncio
49
 
50
+ try:
51
+ async with aiohttp.ClientSession() as session:
52
+ async with session.get("https://ipinfo.io/", timeout=5) as response:
53
+ return True, response
54
+ except asyncio.TimeoutError:
55
+ return False, "Request timed out after 5 seconds. Please try again."
56
+ except aiohttp.ClientError as e:
57
+ return False, "Unable to connect to location service. Please check your internet connection."
58
  except Exception as e:
59
+ return False, f"An unexpected error occurred: {str(e)}"
60
+
61
+ async def forward(self) -> str:
62
+ # Send the request to the API
63
+ success, response = await self._call_api()
64
+ if not success:
65
+ self.logger.error(response)
66
+ return f"Error getting user location: {response}"
67
+
68
+ # Validate response format
69
+ is_valid, result = self._validate_location_data(response)
70
+ if not is_valid:
71
+ self.logger.error(f"Data validation failed: {result}")
72
+ return f"Error getting user location: {result}"
73
+
74
+ self.logger.info(f"Successfully retrieved location: {result}")
75
+ return result