# PenaltyGPT / app.py
# Last commit by Gabriel382: a094f5c "Update app.py" (verified)
import os
import math
import datetime
import pytz
import yaml
from PIL import Image, ImageDraw, ImageFont
from typing import Optional
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool
from tools.final_answer import FinalAnswerTool
from Gradio_UI import GradioUI
# -------------------------------
# Example tool (provided)
# -------------------------------
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:
    """Placeholder tool: ignores its inputs and always returns the same teaser string.

    Args:
        arg1: the first argument.
        arg2: the second argument.
    """
    placeholder_reply = "What magic will you build ?"
    return placeholder_reply
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """Return a sentence giving the current local time in the named timezone.

    Any failure (for example an unknown timezone name) is reported in the returned
    string instead of being raised, so the agent can read it.

    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    stamp_format = "%Y-%m-%d %H:%M:%S"
    try:
        now_there = datetime.datetime.now(pytz.timezone(timezone))
        return f"The current local time in {timezone} is: {now_there.strftime(stamp_format)}"
    except Exception as exc:
        return f"Error fetching time for timezone '{timezone}': {str(exc)}"
# -------------------------------
# Penalty Simulation Tool
# -------------------------------
# Sprite file stems under images/, in load order (which is also the back-to-front
# compositing order used by simulate_penalty).
_PENALTY_SPRITE_NAMES = ("stadium", "goal", "goalkeeper", "ball", "striker")


def _judge_penalty_shot(probability: float, posX: float, posY: float) -> str:
    """Apply the penalty rules and return "Goal!" or "Miss!".

    A shot misses when either coordinate lies outside [0, 1] (NaN and infinities
    included, since they fail the range comparison) or when it touches the frame:
    posX within 1e-6 of 1 or posY within 1e-6 of 0. Any other shot scores exactly
    when probability >= 0.5.
    """
    in_bounds = 0 <= posX <= 1 and 0 <= posY <= 1
    touches_frame = abs(posX - 1) < 1e-6 or abs(posY) < 1e-6
    if not in_bounds or touches_frame:
        return "Miss!"
    return "Goal!" if probability >= 0.5 else "Miss!"


def _load_penalty_sprite(base_dir: str, name: str) -> Image.Image:
    """Load ``<base_dir>/<name>.png`` converted to RGBA.

    The ``with`` block closes the source file deterministically, even when the
    conversion fails; ``convert`` returns a new, independent image.
    """
    with Image.open(os.path.join(base_dir, f"{name}.png")) as raw:
        return raw.convert("RGBA")


def _penalty_error_image(error: Exception) -> Image.Image:
    """Build the 800x600 red placeholder returned when the sprites cannot be loaded."""
    error_img = Image.new("RGBA", (800, 600), (255, 0, 0, 255))
    ImageDraw.Draw(error_img).text((50, 50), f"Error loading images: {str(error)}", fill="white")
    return error_img


def _place_striker(striker_img: Image.Image, canvas_height: int, probability: float, ball_center):
    """Return the striker sprite to paste and its top-left pixel position.

    The default spot is 50 px from the left edge and 50 px above the bottom of the
    canvas. With probability >= 0.5, or when no ball was drawn (``ball_center`` is
    None), the sprite is used unchanged at that spot. Otherwise it is rotated to
    point from the default spot toward the ball centre and moved so its right edge
    sits 10 px left of the ball, vertically centred on it.
    """
    default_pos = (50, canvas_height - striker_img.size[1] - 50)
    if probability >= 0.5 or ball_center is None:
        return striker_img, default_pos
    ball_x, ball_y = ball_center
    # Pixel y grows downward, so atan2 yields a clockwise angle; PIL's rotate()
    # turns counter-clockwise for positive angles, hence the negation.
    angle_deg = math.degrees(math.atan2(ball_y - default_pos[1], ball_x - default_pos[0]))
    rotated = striker_img.rotate(-angle_deg, expand=True)
    rotated_width, rotated_height = rotated.size
    return rotated, (int(ball_x - rotated_width - 10), int(ball_y - rotated_height / 2))


@tool
def simulate_penalty(shot_description: str, probability: float, posX: float, posY: float) -> Image.Image:
    """Simulate a penalty shot from explicit parameters chosen by the agent and render it.

    Coordinate system: the goal's upper-right corner is (0, 0) and its bottom-left
    corner is (1, 1), so posX grows from the right post to the left post and posY
    grows from the crossbar down to the ground.
    Scoring rules, applied in order:
    - If posX or posY is outside [0, 1] (or is not a finite number), the shot is a miss.
    - If posX is within 1e-6 of 1 (left post) or posY is within 1e-6 of 0 (crossbar), the ball hits the frame and the shot is a miss.
    - Otherwise the shot is a goal when probability >= 0.5 and a miss when probability < 0.5.
    Rendering: five PNGs from the "images" folder are layered in this order: stadium
    background, goal, goalkeeper (centred in the goal), ball (placed from posX/posY
    relative to the goal) and striker. When probability < 0.5 the striker is rotated
    toward the ball and moved next to it. The result and the parameter values are
    written at the top of the image, and a copy is saved to output.png when possible.

    Args:
        shot_description: Description of how the ball is shot (not used by the simulation itself).
        probability: A float representing the chance of scoring.
        posX: The horizontal coordinate (0 to 1) where the ball goes.
        posY: The vertical coordinate (0 to 1) where the ball goes.

    Returns:
        The composite PIL Image that Gradio will display, or an 800x600 red error image if the sprites cannot be loaded.
    """
    shot_result = _judge_penalty_shot(probability, posX, posY)

    # Load every sprite up front; any failure yields the red error image instead of a crash.
    base_dir = os.path.join(os.getcwd(), "images")
    try:
        stadium_img, goal_img, goalkeeper_img, ball_img, striker_img = [
            _load_penalty_sprite(base_dir, name) for name in _PENALTY_SPRITE_NAMES
        ]
    except Exception as e:
        return _penalty_error_image(e)

    # Stadium background with the goal pasted at a fixed offset.
    canvas = stadium_img.copy()
    goal_pos = (300, 100)
    canvas.paste(goal_img, goal_pos, goal_img)
    goal_width, goal_height = goal_img.size

    # Goalkeeper centred inside the goal.
    keeper_width, keeper_height = goalkeeper_img.size
    keeper_pos = (
        goal_pos[0] + (goal_width - keeper_width) // 2,
        goal_pos[1] + (goal_height - keeper_height) // 2,
    )
    canvas.paste(goalkeeper_img, keeper_pos, goalkeeper_img)

    # Ball: posX = 0 maps to the goal's right edge and 1 to its left edge; posY = 0
    # maps to the top and 1 to the bottom. NaN/infinite coordinates cannot be turned
    # into pixel offsets (int() would raise), so such a ball is not drawn.
    ball_center = None
    if math.isfinite(posX) and math.isfinite(posY):
        ball_center = (
            goal_pos[0] + (1 - posX) * goal_width,
            goal_pos[1] + posY * goal_height,
        )
        ball_width, ball_height = ball_img.size
        ball_pos = (int(ball_center[0] - ball_width / 2), int(ball_center[1] - ball_height / 2))
        canvas.paste(ball_img, ball_pos, ball_img)

    striker_sprite, striker_pos = _place_striker(striker_img, canvas.size[1], probability, ball_center)
    canvas.paste(striker_sprite, striker_pos, striker_sprite)

    # Result caption with a 2-px black drop shadow; Arial when available, else PIL's default font.
    try:
        font = ImageFont.truetype("arial.ttf", 40)
    except OSError:
        font = ImageFont.load_default()
    text = f"{shot_result} (P={probability:.2f}, X={posX:.2f}, Y={posY:.2f})"
    text_x, text_y = 50, 20
    draw = ImageDraw.Draw(canvas)
    draw.text((text_x + 2, text_y + 2), text, font=font, fill="black")
    draw.text((text_x, text_y), text, font=font, fill="white")

    # Best-effort copy on disk; a failed write must not prevent returning the image.
    try:
        canvas.save(os.path.join(os.getcwd(), "output.png"))
    except OSError:
        pass
    return canvas
# -------------------------------
# Setup for the smolagents CodeAgent and GradioUI
# -------------------------------
final_answer = FinalAnswerTool()

# Remote inference endpoint that backs the agent's reasoning.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud',
    custom_role_conversions=None,
)

# NOTE(review): trust_remote_code=True runs code fetched from the Hub repository
# "agents-course/text-to-image"; only enable it for a repository you trust.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Decode explicitly as UTF-8 so the prompt templates do not depend on the host locale.
with open("prompts.yaml", "r", encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[
        final_answer,
        my_custom_tool,
        get_current_time_in_timezone,
        image_generation_tool,
        simulate_penalty,
    ],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name="PenaltyGPT",
    description="An LLM-based football penalty game where the agent determines shot parameters and calls the simulate_penalty function.",
    prompt_templates=prompt_templates,
)

# Start the Gradio front end around the agent (runs at import time, as in the original).
GradioUI(agent).launch()