Risalat commited on
Commit
d7c374a
·
verified ·
1 Parent(s): f1ead0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -35,8 +35,10 @@ from smolagents import tool
35
  # image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
36
  # Import tool from Hub
37
  image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
 
 
38
  @tool
39
- def image_generator(arg1: str):
40
  """
41
  A tool that generates an image based on a text prompt.
42
 
@@ -50,16 +52,28 @@ def image_generator(arg1: str):
50
  result = image_generation_tool(prompt=arg1)
51
 
52
  # If result is a PIL image, return it directly
53
- if hasattr(result, 'show') and callable(result.show):
54
- return result # Gradio will display this automatically
55
- elif isinstance(result, str):
56
- return f"Generated image URL: {result}"
 
 
 
 
 
 
57
  elif isinstance(result, dict) and "image_url" in result:
58
- return f"Generated image URL: {result['image_url']}"
59
- else:
60
- return f"Image generated: {result}"
 
 
 
 
61
  except Exception as e:
62
- return f"Error generating image: {str(e)}"
 
 
63
 
64
 
65
 
 
35
  # image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
36
  # Import tool from Hub
37
  image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
38
+ from PIL import Image
39
+
40
  @tool
41
+ def image_generator(arg1: str) -> Image.Image:
42
  """
43
  A tool that generates an image based on a text prompt.
44
 
 
52
  result = image_generation_tool(prompt=arg1)
53
 
54
  # If result is a PIL image, return it directly
55
+ if isinstance(result, Image.Image):
56
+ return result
57
+
58
+ # If result is a URL string
59
+ elif isinstance(result, str) and result.startswith("http"):
60
+ response = requests.get(result)
61
+ image = Image.open(BytesIO(response.content))
62
+ return image
63
+
64
+ # If result is a dictionary with an image URL
65
  elif isinstance(result, dict) and "image_url" in result:
66
+ image_url = result["image_url"]
67
+ response = requests.get(image_url)
68
+ image = Image.open(BytesIO(response.content))
69
+ return image
70
+
71
+ return Image.new("RGB", (256, 256), color="gray") # fallback dummy image
72
+
73
  except Exception as e:
74
+ # Show a red dummy image with error in metadata or raise it
75
+ return Image.new("RGB", (256, 256), color="red")
76
+
77
 
78
 
79