File size: 6,153 Bytes
18a33dc
9b5b26a
18a33dc
97a8a05
9b5b26a
 
c19d193
6aae614
9b5b26a
 
97a8a05
 
 
 
 
 
 
 
 
 
 
 
18a33dc
97a8a05
 
 
 
 
 
 
 
 
 
 
 
 
 
18a33dc
97a8a05
 
 
18a33dc
97a8a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a33dc
97a8a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b5b26a
 
 
 
 
 
 
 
 
 
 
 
8c01ffb
 
6aae614
ae7a494
e121372
18a33dc
 
 
 
13d500a
8c01ffb
9b5b26a
 
8c01ffb
861422e
 
18a33dc
8c01ffb
8fe992b
18a33dc
 
 
 
 
8c01ffb
 
 
 
 
 
861422e
8fe992b
 
8c01ffb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool
import datetime
import re
from difflib import get_close_matches
import requests
import pytz
import yaml
from tools.final_answer import FinalAnswerTool
from Gradio_UI import GradioUI

# Earliest season accepted by sanitize_year() and quoted in the tool's
# "invalid year" message. Presumably the first season the OpenF1 API has
# data for -- TODO confirm against the API documentation.
MIN_YEAR = 2023

def sanitize_name(name: str) -> str | None:
    """Only allow letters, hyphens, apostrophes and spaces (covers names like O'Ward, Häkkinen)."""
    if not re.match(r"^[a-zA-ZÀ-ÿ\s'\-]{2,50}$", name):
        return None
    return name.strip()

def sanitize_year(year: int) -> int | None:
    """Only allow integers within a valid range."""
    if not isinstance(year, int) or isinstance(year, bool):
        return None
    if year < MIN_YEAR or year > datetime.datetime.now().year:  # fix 1
        return None
    return year

@tool
def driver_number(name: str, year: int) -> str:
    """A tool that gives the racing number of an F1 driver for a given year.
    Args:
        name: the name of the driver (last name or full name)
        year: the season year to look up (2023 or later)
    Returns:
        The driver number as a string, or a human-readable error message
        when the inputs are invalid, the driver cannot be found or the
        OpenF1 API request fails.
    """

    # Sanitize inputs before anything else
    clean_name = sanitize_name(name)
    if clean_name is None:
        return "Invalid driver name. Only letters, spaces, hyphens and apostrophes are allowed (2-50 characters)."

    clean_year = sanitize_year(year)
    if clean_year is None:
        return f"Invalid year. Please provide a year between {MIN_YEAR} and {datetime.datetime.now().year}."

    try:
        # Step 1: Get any session_key for the given year
        sessions_resp = requests.get(
            "https://api.openf1.org/v1/sessions",
            params={"year": clean_year},
            timeout=10
        )

        if sessions_resp.status_code != 200:
            return f"Error fetching sessions: HTTP {sessions_resp.status_code}"

        sessions = sessions_resp.json()

        if not sessions:
            return f"No sessions found for year {clean_year}."

        # .get() rather than [...] so that a record without the key reaches
        # the validation below instead of surfacing as a raw KeyError.
        session_key = sessions[0].get("session_key")

        # Validate session_key is an integer before using it
        if not isinstance(session_key, int):
            return "Unexpected session data returned from API."

        # Step 2: Search for the driver by last name within that session
        drivers_resp = requests.get(
            "https://api.openf1.org/v1/drivers",
            params={"session_key": session_key, "last_name": clean_name},
            timeout=10
        )

        if drivers_resp.status_code != 200:
            return f"Error fetching drivers: HTTP {drivers_resp.status_code}"

        drivers = drivers_resp.json()

        # Fallback 1: try full name exact match. The query upper-cases the
        # name; a failed request here is treated as "no match" (best effort).
        if not drivers:
            drivers_resp = requests.get(
                "https://api.openf1.org/v1/drivers",
                params={"session_key": session_key, "full_name": clean_name.upper()},
                timeout=10
            )
            drivers = drivers_resp.json() if drivers_resp.status_code == 200 else []

        # Fallback 2: fuzzy match against all drivers in that session
        if not drivers:
            all_drivers_resp = requests.get(
                "https://api.openf1.org/v1/drivers",
                params={"session_key": session_key},
                timeout=10
            )

            if all_drivers_resp.status_code == 200:
                all_drivers = all_drivers_resp.json()

                # Build a mapping of searchable names -> driver object
                searchable_fields = ("last_name", "full_name", "broadcast_name", "name_acronym")
                name_map = {}
                for d in all_drivers:
                    for field in searchable_fields:
                        value = d.get(field)
                        # A field that is missing or null must be skipped:
                        # calling .lower() on None would abort the whole
                        # lookup, and "" is not a meaningful candidate.
                        if isinstance(value, str) and value:
                            name_map[value.lower()] = d

                matches = get_close_matches(
                    clean_name.lower(),
                    list(name_map),
                    n=1,
                    cutoff=0.6
                )

                if matches:
                    drivers = [name_map[matches[0]]]

        if not drivers:
            return f"No driver found with name '{clean_name}' in {clean_year}. Check spelling or try the driver's last name only."

        # Validate the returned driver_number is actually an integer
        driver_num = drivers[0].get("driver_number")
        if not isinstance(driver_num, int):
            return "Unexpected driver data returned from API."

        return str(driver_num)

    except requests.exceptions.Timeout:
        return "Request timed out. The OpenF1 API may be slow or unavailable."
    except requests.exceptions.ConnectionError:
        return "Could not connect to the OpenF1 API. Check network connectivity."
    except Exception as e:
        # Tool boundary: the agent expects a string back, so any remaining
        # failure (bad JSON, unexpected payload shape, ...) is reported as text.
        return f"Unexpected error: {str(e)}"

@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.
    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    try:
        # Resolve the zone name and take "now" directly in that zone.
        now_in_zone = datetime.datetime.now(pytz.timezone(timezone))
        formatted = now_in_zone.strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        # Any failure (e.g. an unknown zone name) is reported back as text
        # rather than raised, so the caller always receives a string.
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
    else:
        return f"The current local time in {timezone} is: {formatted}"


# Tool instance registered first in the agent's tool list below.
final_answer = FinalAnswerTool()

# Language model that backs the agent.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)

# Import tool from Hub
# NOTE(review): this tool is loaded but never added to the agent's `tools`
# list below -- confirm whether it should be registered or the load dropped.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Read the prompt templates explicitly as UTF-8 so parsing does not depend on
# the platform's default locale encoding.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[
        final_answer,                 # always keep this one
        get_current_time_in_timezone, # pre-made timezone tool
        driver_number,                # your custom F1 tool
    ],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

# Start the Gradio web UI wrapping the agent.
GradioUI(agent).launch()