File size: 3,190 Bytes
c0163a7
 
26e3b0f
50d6f1a
d6b1954
 
c0163a7
 
 
 
 
 
 
371df9f
c0163a7
2fd68f6
371df9f
 
2fd68f6
371df9f
 
 
2fd68f6
371df9f
 
 
 
2fd68f6
371df9f
 
 
 
2fd68f6
371df9f
 
2fd68f6
371df9f
 
 
 
2fd68f6
371df9f
 
2fd68f6
371df9f
 
 
 
2fd68f6
371df9f
 
 
 
 
 
 
 
c0163a7
371df9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Any, Optional
from smolagents.tools import Tool
import json
import asyncio
import logging
import aiohttp

class UserLocationTool(Tool):
    """Tool that reports an approximate location derived from the caller's IP address.

    The lookup is a single GET request to ipinfo.io. Every failure mode
    (network error, timeout, non-200 status, malformed payload) is turned into
    a human-readable error string rather than an exception, so ``forward``
    always resolves to a ``str``.

    NOTE(review): imports are repeated inside the methods, presumably so each
    method stays self-contained when the tool is exported/shared by smolagents
    -- confirm before hoisting them to module level.
    """

    name = "get_user_location"
    description = "This tool returns the user's location based on their IP address. That is all it does."
    inputs = {}
    output_type = "string"

    def __init__(self):
        """Initialise the base tool and attach a logger set to INFO level."""
        super().__init__()
        import logging

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    def _validate_location_data(self, response: Any) -> tuple[bool, str]:
        """Check the HTTP status and payload of a geolocation response.

        Args:
            response: Any object exposing an integer ``status`` attribute and
                a *synchronous* ``json()`` method that returns the decoded
                payload or raises ``json.JSONDecodeError``. ``_call_api``
                returns exactly such an object.

        Returns:
            ``(True, location_sentence)`` when the payload holds non-empty
            ``city``, ``region`` and ``country`` fields, otherwise
            ``(False, error_message)``.
        """
        import json

        if response.status != 200:
            if response.status == 429:
                return False, "Rate limit exceeded for IP geolocation service. Please try again later."
            return False, f"Server returned HTTP {response.status}"

        try:
            response_data = response.json()
        except json.JSONDecodeError:
            return False, "Invalid response from server"

        if not isinstance(response_data, dict):
            return False, f"Expected dictionary data, got {type(response_data)}"

        # A field counts as missing when it is absent *or* empty/falsy.
        required_fields = ['city', 'region', 'country']
        missing_fields = [field for field in required_fields if not response_data.get(field)]
        if missing_fields:
            return False, f"Missing or empty location fields: {', '.join(missing_fields)}"

        location = f"According to your IP address, the user's location is {response_data['city']}, {response_data['region']}, {response_data['country']}."
        return True, location

    async def _call_api(self) -> tuple[bool, Any]:
        """Fetch the geolocation record from ipinfo.io.

        The status code and body are read while the connection is still open;
        an aiohttp response must not be consumed after its ``async with``
        block has exited, and its own ``json()`` is a coroutine that the
        synchronous validator could not call.

        Returns:
            ``(True, snapshot)`` on any completed HTTP exchange, where
            ``snapshot`` has a ``status`` attribute and a synchronous
            ``json()`` method, or ``(False, error_message)`` when the request
            itself failed.
        """
        import asyncio
        import json
        from types import SimpleNamespace

        import aiohttp

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    "https://ipinfo.io/",
                    # Ask for JSON explicitly instead of relying on the
                    # service's user-agent sniffing.
                    headers={"Accept": "application/json"},
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as response:
                    status = response.status
                    body = await response.text()
        except asyncio.TimeoutError:
            return False, "Request timed out after 5 seconds. Please try again."
        except aiohttp.ClientError:
            return False, "Unable to connect to location service. Please check your internet connection."
        except Exception as e:
            # Last-resort boundary: the tool reports failures as text.
            return False, f"An unexpected error occurred: {str(e)}"

        # Parsing is deferred to json() so the validator keeps ownership of
        # the "Invalid response from server" error path.
        snapshot = SimpleNamespace(status=status, json=lambda: json.loads(body))
        return True, snapshot

    async def forward(self) -> str:
        """Look up the location and return it (or an error description) as text.

        NOTE(review): this is a coroutine; the caller has to await it. Confirm
        that the framework invoking the tool does so.
        """
        # Send the request to the API
        success, response = await self._call_api()
        if not success:
            self.logger.error(response)
            return f"Error getting user location: {response}"

        # Validate response format
        is_valid, result = self._validate_location_data(response)
        if not is_valid:
            self.logger.error("Data validation failed: %s", result)
            return f"Error getting user location: {result}"

        self.logger.info("Successfully retrieved location: %s", result)
        return result