Erfan committed on
Commit
20f10c1
·
1 Parent(s): 85ba1bc

Fix gradio UI and add image saving tool and .gitignore

Browse files
Files changed (7) hide show
  1. .gitattributes +11 -0
  2. .gitignore +35 -0
  3. Gradio_UI.py +21 -4
  4. app.py +52 -42
  5. prompts.yaml +1 -5
  6. tools/save_image.py +83 -0
  7. tools/web_search.py +38 -21
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+
37
+ # Images (store in Git LFS)
38
+ *.png filter=lfs diff=lfs merge=lfs -text
39
+ *.jpg filter=lfs diff=lfs merge=lfs -text
40
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
41
+ *.gif filter=lfs diff=lfs merge=lfs -text
42
+ *.webp filter=lfs diff=lfs merge=lfs -text
43
+ *.bmp filter=lfs diff=lfs merge=lfs -text
44
+ *.tif filter=lfs diff=lfs merge=lfs -text
45
+ *.tiff filter=lfs diff=lfs merge=lfs -text
46
+ *.ico filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Virtual environments
7
+ venv/
8
+ .venv/
9
+ env/
10
+ .env/
11
+
12
+ # Packaging/build
13
+ build/
14
+ dist/
15
+ *.egg-info/
16
+
17
+ # Tooling caches
18
+ .pytest_cache/
19
+ .mypy_cache/
20
+ .ruff_cache/
21
+ .coverage
22
+ coverage.xml
23
+
24
+ # OS/editor
25
+ .DS_Store
26
+ Thumbs.db
27
+ .idea/
28
+ .vscode/
29
+
30
+ # Gradio runtime artifacts
31
+ .gradio/
32
+
33
+ # App outputs
34
+ generated_images/
35
+ uploads/
Gradio_UI.py CHANGED
@@ -25,7 +25,7 @@ from typing import Optional
25
  import requests
26
  from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
27
  from smolagents.agents import ActionStep, MultiStepAgent
28
- from smolagents.memory import MemoryStep
29
  from smolagents.utils import _is_package_available
30
 
31
 
@@ -215,10 +215,28 @@ def stream_to_gradio(agent, task: str, reset_agent_memory: bool = False, additio
215
  for message in pull_messages_from_step(step_log):
216
  yield message
217
 
218
- final_answer = handle_agent_output_types(step_log)
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  if isinstance(final_answer, AgentText):
221
- yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n{final_answer.to_string()}\n")
 
 
 
 
 
 
 
222
 
223
  elif isinstance(final_answer, AgentImage):
224
  img_path = _save_agent_image(final_answer)
@@ -314,7 +332,6 @@ class GradioUI:
314
  file_uploads_log = gr.State([])
315
  chatbot = gr.Chatbot(
316
  label="Agent",
317
- type="messages",
318
  avatar_images=(
319
  None,
320
  "https://huggingface.co/datasets/agents-course/course-images/resolve/main/en/communication/Alfred.png",
 
25
  import requests
26
  from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
27
  from smolagents.agents import ActionStep, MultiStepAgent
28
+ from smolagents.memory import FinalAnswerStep, MemoryStep
29
  from smolagents.utils import _is_package_available
30
 
31
 
 
215
  for message in pull_messages_from_step(step_log):
216
  yield message
217
 
218
+ raw_final_answer = step_log.final_answer if isinstance(step_log, FinalAnswerStep) else step_log
219
+
220
+ # If a tool returns a local image path (e.g. via `save_image`), render it inline in the chat.
221
+ if isinstance(raw_final_answer, str):
222
+ candidate_path = raw_final_answer.strip()
223
+ if candidate_path and os.path.exists(candidate_path):
224
+ mime_type, _ = mimetypes.guess_type(candidate_path)
225
+ if mime_type and mime_type.startswith("image/"):
226
+ yield gr.ChatMessage(role="assistant", content={"path": candidate_path, "mime_type": mime_type})
227
+ return
228
+
229
+ final_answer = handle_agent_output_types(raw_final_answer)
230
 
231
  if isinstance(final_answer, AgentText):
232
+ # If the text is actually a local image path, render the image.
233
+ text = final_answer.to_string().strip()
234
+ if text and os.path.exists(text):
235
+ mime_type, _ = mimetypes.guess_type(text)
236
+ if mime_type and mime_type.startswith("image/"):
237
+ yield gr.ChatMessage(role="assistant", content={"path": text, "mime_type": mime_type})
238
+ return
239
+ yield gr.ChatMessage(role="assistant", content=f"**Final answer:**\n{text}\n")
240
 
241
  elif isinstance(final_answer, AgentImage):
242
  img_path = _save_agent_image(final_answer)
 
332
  file_uploads_log = gr.State([])
333
  chatbot = gr.Chatbot(
334
  label="Agent",
 
335
  avatar_images=(
336
  None,
337
  "https://huggingface.co/datasets/agents-course/course-images/resolve/main/en/communication/Alfred.png",
app.py CHANGED
@@ -1,9 +1,11 @@
1
- from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
2
  import datetime
3
  import requests
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
 
 
7
 
8
  from Gradio_UI import GradioUI
9
 
@@ -15,7 +17,7 @@ def name_meaning(name: str) -> str:
15
  name: A name to look up.
16
  """
17
  return (
18
- f"CALL DuckDuckGoSearchTool with query: '{name} name meaning origin'. "
19
  f"Prefer sources like BehindTheName, Nameberry, Oxford Reference, Britannica. "
20
  f"Return: origin, meaning, variants, and 1-2 links."
21
  )
@@ -76,43 +78,51 @@ def get_weather(city: str) -> str:
76
  return f"Weather lookup failed: {str(e)}"
77
 
78
 
79
- final_answer = FinalAnswerTool()
80
-
81
- # If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
82
- # model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
83
-
84
- model = HfApiModel(
85
- max_tokens=2096,
86
- temperature=0.5,
87
- model_id='Qwen/Qwen2.5-Coder-32B-Instruct',# it is possible that this model may be overloaded
88
- custom_role_conversions=None,
89
- )
90
-
91
-
92
- # Import tool from Hub
93
- image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
94
-
95
- with open("prompts.yaml", 'r') as stream:
96
- prompt_templates = yaml.safe_load(stream)
97
-
98
- agent = CodeAgent(
99
- model=model,
100
- tools=[
101
- final_answer,
102
- DuckDuckGoSearchTool(), # Web search
103
- image_generation_tool, # HF Hub tool
104
- name_meaning, # Your custom tool
105
- get_current_time_in_timezone,
106
- get_weather # Extra creative tool!
107
- ],
108
- max_steps=6,
109
- verbosity_level=1,
110
- grammar=None,
111
- planning_interval=None,
112
- name=None,
113
- description=None,
114
- prompt_templates=prompt_templates
115
- )
116
-
117
-
118
- GradioUI(agent).launch()
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, HfApiModel, load_tool, tool
2
  import datetime
3
  import requests
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
+ from tools.save_image import SaveImageTool
8
+ from tools.web_search import WebSearchTool
9
 
10
  from Gradio_UI import GradioUI
11
 
 
17
  name: A name to look up.
18
  """
19
  return (
20
+ f"CALL web_search with query: '{name} name meaning origin'. "
21
  f"Prefer sources like BehindTheName, Nameberry, Oxford Reference, Britannica. "
22
  f"Return: origin, meaning, variants, and 1-2 links."
23
  )
 
78
  return f"Weather lookup failed: {str(e)}"
79
 
80
 
81
+ def build_agent() -> CodeAgent:
82
+ final_answer = FinalAnswerTool()
83
+ save_image = SaveImageTool()
84
+ web_search = WebSearchTool()
85
+
86
+ # If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
87
+ # model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
88
+ model = HfApiModel(
89
+ max_tokens=2096,
90
+ temperature=0.5,
91
+ model_id="Qwen/Qwen2.5-Coder-32B-Instruct", # it is possible that this model may be overloaded
92
+ custom_role_conversions=None,
93
+ )
94
+
95
+ # Import tool from Hub
96
+ image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
97
+
98
+ with open("prompts.yaml", "r") as stream:
99
+ prompt_templates = yaml.safe_load(stream)
100
+
101
+ return CodeAgent(
102
+ model=model,
103
+ tools=[
104
+ final_answer,
105
+ save_image,
106
+ web_search, # Web search
107
+ image_generation_tool, # HF Hub tool
108
+ name_meaning, # Your custom tool
109
+ get_current_time_in_timezone,
110
+ get_weather, # Extra creative tool!
111
+ ],
112
+ max_steps=6,
113
+ verbosity_level=1,
114
+ grammar=None,
115
+ planning_interval=None,
116
+ name=None,
117
+ description=None,
118
+ prompt_templates=prompt_templates,
119
+ )
120
+
121
+
122
+ def main():
123
+ agent = build_agent()
124
+ GradioUI(agent).launch()
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()
prompts.yaml CHANGED
@@ -24,12 +24,8 @@
24
  Thought: I will now generate an image showcasing the oldest person, save it locally, and return the path.
25
  Code:
26
  ```py
27
- import os
28
- os.makedirs("generated_images", exist_ok=True)
29
-
30
  image = image_generator(prompt="A portrait of John Doe, a 55-year-old man living in Canada.")
31
- path = "generated_images/john_doe.png"
32
- image.save(path)
33
 
34
  final_answer(path)
35
  ```<end_code>
 
24
  Thought: I will now generate an image showcasing the oldest person, save it locally, and return the path.
25
  Code:
26
  ```py
 
 
 
27
  image = image_generator(prompt="A portrait of John Doe, a 55-year-old man living in Canada.")
28
+ path = save_image(image=image, filename="john_doe.png")
 
29
 
30
  final_answer(path)
31
  ```<end_code>
tools/save_image.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ import shutil
5
+ import uuid
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import Any, Optional
9
+
10
+ from smolagents.tools import Tool
11
+
12
+
13
+ class SaveImageTool(Tool):
14
+ name = "save_image"
15
+ description = (
16
+ "Save an image to the local `generated_images/` folder and return the saved file path. "
17
+ "Use this instead of importing `os` for filesystem operations."
18
+ )
19
+ inputs = {
20
+ "image": {"type": "any", "description": "An image object (PIL, AgentImage), bytes, or an existing file path."},
21
+ "filename": {
22
+ "type": "string",
23
+ "description": "Optional output filename (e.g. `cat.png`). Defaults to a random `.png` name.",
24
+ "nullable": True,
25
+ },
26
+ }
27
+ output_type = "string"
28
+
29
+ def forward(self, image: Any, filename: Optional[str] = None) -> str:
30
+ base_dir = Path("generated_images")
31
+ base_dir.mkdir(parents=True, exist_ok=True)
32
+ base_dir_resolved = base_dir.resolve()
33
+
34
+ safe_name = self._sanitize_filename(filename) if filename else f"image_{uuid.uuid4().hex[:8]}.png"
35
+ out_path = (base_dir / safe_name).resolve()
36
+ if not out_path.is_relative_to(base_dir_resolved):
37
+ raise ValueError("Refusing to write outside `generated_images/`.")
38
+
39
+ # If `image` is already a path on disk, just copy it.
40
+ if isinstance(image, (str, Path)):
41
+ src = Path(image)
42
+ if src.exists() and src.is_file():
43
+ shutil.copyfile(src, out_path)
44
+ return str(out_path)
45
+
46
+ pil_img = self._to_pil(image)
47
+ pil_img.save(out_path)
48
+ return str(out_path)
49
+
50
+ @staticmethod
51
+ def _sanitize_filename(filename: str) -> str:
52
+ name = Path(filename).name # drop any path parts
53
+ name = re.sub(r"[^A-Za-z0-9._-]", "_", name).strip("._")
54
+ if not name:
55
+ name = f"image_{uuid.uuid4().hex[:8]}.png"
56
+ if "." not in name:
57
+ name += ".png"
58
+ return name
59
+
60
+ @staticmethod
61
+ def _to_pil(image: Any):
62
+ from PIL import Image
63
+
64
+ if hasattr(image, "save"):
65
+ return image
66
+
67
+ if hasattr(image, "to_pil"):
68
+ pil = image.to_pil()
69
+ if pil is not None and hasattr(pil, "save"):
70
+ return pil
71
+
72
+ if isinstance(image, (bytes, bytearray)):
73
+ return Image.open(BytesIO(image))
74
+
75
+ if hasattr(image, "to_string"):
76
+ as_str = image.to_string()
77
+ if isinstance(as_str, str):
78
+ p = Path(as_str)
79
+ if p.exists() and p.is_file():
80
+ return Image.open(p)
81
+
82
+ raise TypeError(f"Unsupported image type for saving: {type(image)}")
83
+
tools/web_search.py CHANGED
@@ -1,27 +1,44 @@
 
 
1
  from typing import Any, Optional
 
2
  from smolagents.tools import Tool
3
- import duckduckgo_search
4
 
5
- class DuckDuckGoSearchTool(Tool):
 
6
  name = "web_search"
7
- description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
8
- inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
9
- output_type = "string"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def __init__(self, max_results=10, **kwargs):
12
- super().__init__()
13
- self.max_results = max_results
14
  try:
15
- from duckduckgo_search import DDGS
16
- except ImportError as e:
17
- raise ImportError(
18
- "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
19
- ) from e
20
- self.ddgs = DDGS(**kwargs)
21
-
22
- def forward(self, query: str) -> str:
23
- results = self.ddgs.text(query, max_results=self.max_results)
24
- if len(results) == 0:
25
- raise Exception("No results found! Try a less restrictive/shorter query.")
26
- postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
27
- return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
 
 
 
 
1
+ from __future__ import annotations
2
+
3
  from typing import Any, Optional
4
+
5
  from smolagents.tools import Tool
 
6
 
7
+
8
+ class WebSearchTool(Tool):
9
  name = "web_search"
10
+ description = "Search the web with DuckDuckGo and return a short list of results."
11
+ inputs = {
12
+ "query": {"type": "string", "description": "Search query."},
13
+ "max_results": {
14
+ "type": "integer",
15
+ "description": "Maximum number of results to return (1-10). Defaults to 5.",
16
+ "nullable": True,
17
+ },
18
+ }
19
+ output_type = "any"
20
+
21
+ def forward(self, query: str, max_results: Optional[int] = 5) -> Any:
22
+ if not isinstance(query, str) or not query.strip():
23
+ raise ValueError("`query` must be a non-empty string.")
24
+
25
+ limit = 5 if max_results is None else int(max_results)
26
+ limit = max(1, min(10, limit))
27
 
 
 
 
28
  try:
29
+ from ddgs import DDGS # type: ignore
30
+ except Exception as e: # pragma: no cover
31
+ raise ModuleNotFoundError("Missing dependency: `ddgs` (pip install ddgs).") from e
32
+
33
+ results: list[dict[str, str]] = []
34
+ with DDGS() as ddgs:
35
+ for r in ddgs.text(query, max_results=limit):
36
+ title = (r.get("title") or "").strip()
37
+ url = (r.get("href") or r.get("url") or "").strip()
38
+ snippet = (r.get("body") or r.get("snippet") or "").strip()
39
+ if not (title or url or snippet):
40
+ continue
41
+ results.append({"title": title, "url": url, "snippet": snippet})
42
+
43
+ return results
44
+