whiran's picture
Update app.py
4aff01c verified
import datetime
import functools
from pathlib import Path
from typing import Optional

import pytz
import requests
import torch
import yaml
from diffusers import StableDiffusionPipeline
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool

from Gradio_UI import GradioUI
from tools.final_answer import FinalAnswerTool
@functools.lru_cache(maxsize=1)
def _load_pipeline(model_name: str, device: str) -> StableDiffusionPipeline:
    """Load a Stable Diffusion pipeline once and reuse it across tool calls.

    ``maxsize=1`` keeps at most one pipeline referenced by the cache, so
    switching models does not accumulate several large pipelines in memory.
    """
    # float16 halves memory on GPU; half precision is poorly supported on CPU,
    # so fall back to float32 there.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype)
    return pipe.to(device)


@tool
def generate_image_tool(
    prompt: str,
    model_name: str = "stabilityai/stable-diffusion-2-1",
    num_inference_steps: int = 50,
    output_path: str = "generated_image.png",
) -> str:
    """A tool for generating images from text prompts using Stable Diffusion.

    Args:
        prompt: Text description of the image (must be SFW)
        model_name: AI model to use (default: stabilityai/stable-diffusion-2-1)
        num_inference_steps: Quality steps (20-100)
        output_path: Where to save the image

    Returns:
        A success message with the saved path, or a message starting with
        "Error" / "Image generation failed" describing what went wrong.
    """
    # Reject empty / missing prompts up front instead of crashing on .lower().
    if not prompt or not prompt.strip():
        return "Error: prompt must be a non-empty string"

    # Simple substring screen. It also matches inside longer words
    # (e.g. "adult" in "adulthood"), erring on the side of refusing.
    unsafe_keywords = ("nude", "porn", "explicit", "adult", "nsfw")
    lowered = prompt.lower()
    if any(kw in lowered for kw in unsafe_keywords):
        return "Error: Content policy violation detected in prompt"

    if num_inference_steps < 1:
        return "Error: num_inference_steps must be a positive integer"

    # Use the GPU when present; otherwise run (slowly) on CPU rather than fail.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        # Cached: the expensive model load happens once per (model, device).
        pipe = _load_pipeline(model_name, device)
        image = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
        ).images[0]
        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        image.save(destination)
    except Exception as e:
        # Tool boundary: the agent consumes output as text, so failures are
        # reported as a message rather than propagated.
        return f"Image generation failed: {str(e)}"
    return f"Image generated successfully at: {output_path}"
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.

    Args:
        timezone: A valid timezone (e.g., 'America/New_York')

    Returns:
        "Current time in <timezone>: YYYY-MM-DD HH:MM:SS", or a message
        starting with "Error:" if the timezone cannot be resolved.
    """
    try:
        tz = pytz.timezone(timezone)
    except pytz.exceptions.UnknownTimeZoneError:
        # UnknownTimeZoneError subclasses KeyError, whose str() is only the
        # quoted key; spell out the problem so the agent can self-correct.
        return (
            f"Error: Unknown timezone '{timezone}'. "
            "Use an IANA name such as 'America/New_York'."
        )
    except Exception as e:
        # Tool boundary: report any other failure (e.g. a non-string
        # argument) as text instead of raising into the agent loop.
        return f"Error: {str(e)}"
    local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
    return f"Current time in {timezone}: {local_time}"
# Tool instance passed to the agent alongside the custom tools above.
final_answer = FinalAnswerTool()

# LLM backing the agent (smolagents HfApiModel).
# NOTE(review): 2096 may be a typo for 2048 — confirm the intended token budget.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    custom_role_conversions=None,
)

# Prompt templates for the CodeAgent, read from prompts.yaml relative to the
# current working directory. Explicit UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open("prompts.yaml", encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# NOTE(review): newer smolagents releases may require `name` to be a valid
# Python identifier (no spaces) — confirm against the pinned version.
agent = CodeAgent(
    model=model,
    tools=[generate_image_tool, get_current_time_in_timezone, final_answer],
    max_steps=6,
    verbosity_level=1,
    prompt_templates=prompt_templates,
    grammar=None,
    planning_interval=None,
    name="Creative Assistant",
    description="AI assistant capable of generating images and providing time information",
)

# Launch the Gradio UI for the agent. This runs at import time (no __main__ guard).
GradioUI(agent).launch()