Spaces:
Running
Running
"""
Image agent backend — multimodal agent with HuggingFace image generation tools.

Uses the same tool-calling loop pattern as agent.py:
    LLM call → parse tool_calls → execute → update history → repeat

Key difference: maintains a figure store (Dict[str, dict]) mapping names like
"figure_T1_1" to {"type": "png", "data": <base64>} entries, so the VLM can
reference images across tool calls without passing huge base64 strings in
arguments.
"""
| import base64 | |
| import json | |
| import logging | |
| import re | |
| from typing import List, Dict, Optional | |
| from .tools import ( | |
| generate_image, edit_image, read_image, save_image, | |
| execute_generate_image, execute_edit_image, execute_read_image, | |
| ) | |
logger = logging.getLogger(__name__)
# Tool definitions handed to the LLM on every call (passed as ``tools=TOOLS``)
TOOLS = [generate_image, edit_image, read_image, save_image]
# Max dimension for images sent to the VLM context (keeps token count manageable)
VLM_IMAGE_MAX_DIM = 512
# JPEG quality used when re-encoding images for the VLM context
VLM_IMAGE_JPEG_QUALITY = 70
def resize_image_for_vlm(base64_png: str) -> str:
    """Return a small base64 JPEG thumbnail of a base64-encoded image.

    The image is decoded, converted to a JPEG-compatible mode, shrunk so that
    its longest side is at most ``VLM_IMAGE_MAX_DIM`` pixels (it is never
    enlarged) and re-encoded as JPEG at ``VLM_IMAGE_JPEG_QUALITY``.

    Args:
        base64_png: Base64-encoded image data (normally a full-resolution PNG).

    Returns:
        The base64-encoded JPEG thumbnail. If anything goes wrong (Pillow not
        installed, undecodable data, encoder error) the failure is logged and
        ``base64_png`` is returned unchanged, so the caller still has an image.
    """
    try:
        # Imported lazily: a missing Pillow install then only degrades this
        # helper (handled by the except below) instead of breaking the module.
        from PIL import Image
        import io as _io

        img = Image.open(_io.BytesIO(base64.b64decode(base64_png)))

        # JPEG has no alpha channel or palette. Convert every mode other than
        # plain RGB / greyscale (RGBA, P, LA, PA, I;16, ...), not just RGBA and
        # P. Do it *before* resizing: Pillow always resamples palette images
        # with nearest-neighbour, which would defeat the LANCZOS filter below.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")

        # Shrink only; thumbnail() keeps the aspect ratio and works in place.
        if max(img.size) > VLM_IMAGE_MAX_DIM:
            img.thumbnail((VLM_IMAGE_MAX_DIM, VLM_IMAGE_MAX_DIM), Image.LANCZOS)

        # Save as JPEG for a much smaller base64 payload
        buffer = _io.BytesIO()
        img.save(buffer, format="JPEG", quality=VLM_IMAGE_JPEG_QUALITY)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        # Deliberately broad: this is a best-effort optimisation.
        logger.error("Failed to resize image for VLM: %s", e)
        # Fall back to original — better to try than to lose the image entirely
        return base64_png
# Upper bound on LLM round-trips in one stream_image_execution() loop
MAX_TURNS = 20
def _resolve_image_bytes(source, image_store, files_root):
    """Return the raw bytes for *source*, or ``None`` if it cannot be resolved.

    A name present in *image_store* takes precedence; anything else is handed
    to ``execute_read_image`` (URL or local path).
    """
    if source in image_store:
        return base64.b64decode(image_store[source]["data"])
    encoded = execute_read_image(source, files_root=files_root)
    return base64.b64decode(encoded) if encoded else None


def _store_image(image_store, image_prefix, image_counter, base64_png):
    """Store *base64_png* under the next sequential name; return ``(name, counter)``."""
    image_counter += 1
    name = f"{image_prefix}{image_counter}"
    image_store[name] = {"type": "png", "data": base64_png}
    return name, image_counter


def execute_tool(
    tool_name: str,
    args: dict,
    hf_token: str,
    image_store: dict,
    image_counter: int,
    default_gen_model: Optional[str] = None,
    default_edit_model: Optional[str] = None,
    files_root: Optional[str] = None,
    image_prefix: str = "figure_",
) -> dict:
    """
    Execute a tool by name and return a result dict.

    Args:
        tool_name: "generate_image", "edit_image", "save_image", "read_image"
            or "read_image_url"; any other name yields an "Unknown tool" result.
        args: Tool arguments as parsed from the LLM's tool call. Missing or
            null values are treated as empty.
        hf_token: Token forwarded to the image generation / editing helpers.
        image_store: Mapping of image name -> {"type": "png", "data": <base64>}.
            Mutated in place whenever a tool produces a new image.
        image_counter: Number of the most recently named image; the next image
            is named ``f"{image_prefix}{image_counter + 1}"``.
        default_gen_model: Model used by generate_image when args has none.
        default_edit_model: Model used by edit_image when args has none.
        files_root: Forwarded to ``execute_read_image`` and used as the target
            directory of save_image (current directory when unset).
        image_prefix: Prefix of generated image names.

    Returns:
        dict with keys:
        - "content": str result for the LLM
        - "image": optional base64 image data
        - "image_name": optional image reference name (e.g., "figure_1")
        - "display": dict with display-friendly data for frontend
        - "image_counter": updated counter
    """
    if tool_name == "generate_image":
        prompt = args.get("prompt") or ""
        model = args.get("model") or default_gen_model or "black-forest-labs/FLUX.1-schnell"
        base64_png, error = execute_generate_image(prompt, hf_token, model)
        if not base64_png:
            return {
                "content": f"Failed to generate image: {error}",
                "display": {"type": "generate_error", "prompt": prompt},
                "image_counter": image_counter,
            }
        name, image_counter = _store_image(image_store, image_prefix, image_counter, base64_png)
        return {
            "content": f"Image generated successfully as '{name}'. The image is attached.",
            "image": base64_png,
            "image_name": name,
            "display": {"type": "generate", "prompt": prompt, "model": model, "image_name": name},
            "image_counter": image_counter,
        }

    if tool_name == "edit_image":
        prompt = args.get("prompt") or ""
        source = args.get("source") or ""
        model = args.get("model") or default_edit_model or "black-forest-labs/FLUX.1-Kontext-dev"
        # Resolve source: image store reference, URL, or local path
        source_bytes = _resolve_image_bytes(source, image_store, files_root)
        if source_bytes is None:
            return {
                "content": f"Could not resolve image source '{source}'. Use a URL or a reference from a previous tool call (e.g., 'figure_T1_1').",
                "display": {"type": "edit_error", "source": source},
                "image_counter": image_counter,
            }
        base64_png, error = execute_edit_image(prompt, source_bytes, hf_token, model)
        if not base64_png:
            return {
                "content": f"Failed to edit image: {error}",
                "display": {"type": "edit_error", "source": source},
                "image_counter": image_counter,
            }
        name, image_counter = _store_image(image_store, image_prefix, image_counter, base64_png)
        return {
            "content": f"Image edited successfully as '{name}'. The image is attached.",
            "image": base64_png,
            "image_name": name,
            "display": {"type": "edit", "prompt": prompt, "source": source, "model": model, "image_name": name},
            "image_counter": image_counter,
        }

    if tool_name == "save_image":
        import os

        source = args.get("source") or ""
        # Resolve source from image store or URL
        image_data = _resolve_image_bytes(source, image_store, files_root)
        if image_data is None:
            return {
                "content": f"Could not resolve image source '{source}'. Use a reference (e.g., 'figure_T1_1') or a URL.",
                "display": {"type": "save_error", "source": source},
                "image_counter": image_counter,
            }
        # Sanitize first (drop any directory part), then fall back to the
        # default name so a null, empty or directory-only filename does not
        # end up as a hidden ".png" file.
        filename = os.path.basename(args.get("filename") or "") or "image.png"
        # Ensure .png extension
        if not filename.lower().endswith(".png"):
            filename += ".png"
        save_dir = files_root or "."
        try:
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, filename), "wb") as f:
                f.write(image_data)
        except OSError as e:
            # Report the failure to the LLM like every other tool error
            # instead of letting it propagate and abort the whole stream.
            return {
                "content": f"Failed to save image as '{filename}': {e}",
                "display": {"type": "save_error", "source": source},
                "image_counter": image_counter,
            }
        return {
            "content": f"Image saved as '{filename}'.",
            # Include base64 so frontend can show a preview of the saved image
            "image": base64.b64encode(image_data).decode("utf-8"),
            "display": {"type": "save_image", "filename": filename, "source": source},
            "image_counter": image_counter,
        }

    if tool_name in ("read_image", "read_image_url"):
        # "url" is accepted as an alias of "source"
        source = args.get("source") or args.get("url") or ""
        base64_png = execute_read_image(source, files_root=files_root)
        if base64_png:
            name, image_counter = _store_image(image_store, image_prefix, image_counter, base64_png)
            return {
                "content": f"Image loaded successfully as '{name}'. The image is attached.",
                "image": base64_png,
                "image_name": name,
                "display": {"type": "read_image", "url": source, "image_name": name},
                "image_counter": image_counter,
            }
        # Provide more specific error for SVG files
        lowered = source.lower()
        if lowered.endswith(".svg") or "/svg" in lowered:
            error_msg = f"Failed to load image from '{source}'. SVG format is not supported — only raster formats (PNG, JPEG, GIF, WebP, BMP) are accepted. Ask the user for a raster version of the image."
        else:
            error_msg = f"Failed to load image from '{source}'. Check that the path or URL is correct and that it is a raster image (PNG, JPEG, GIF, WebP, BMP)."
        return {
            "content": error_msg,
            "display": {"type": "read_image_error", "url": source},
            "image_counter": image_counter,
        }

    return {
        "content": f"Unknown tool: {tool_name}",
        "display": {"type": "error"},
        "image_counter": image_counter,
    }
def stream_image_execution(
    client,
    model: str,
    messages: List[Dict],
    hf_token: str,
    image_gen_model: Optional[str] = None,
    image_edit_model: Optional[str] = None,
    extra_params: Optional[Dict] = None,
    abort_event=None,
    files_root: Optional[str] = None,
    multimodal: bool = False,
    tab_id: str = "0",
    image_store: Optional[Dict[str, dict]] = None,
    image_counter: int = 0,
):
    """
    Run the image agent tool-calling loop.

    Each turn calls the LLM, relays its text, executes the requested tools and
    appends the assistant/tool messages to ``messages`` (mutated in place). The
    loop stops when the model answers without tool calls or after ``MAX_TURNS``
    turns. If no ``<result>`` block has been sent by then, ``nudge_for_result``
    is asked for one; if that fails too, a result listing every stored figure
    is synthesised.

    Args:
        client, model: Forwarded to ``call_llm`` / ``nudge_for_result``.
        messages: Conversation history; extended in place.
        hf_token: Forwarded to ``execute_tool``.
        image_gen_model, image_edit_model: Default models for the generate /
            edit tools when a tool call does not name one.
        extra_params: Forwarded to ``call_llm`` / ``nudge_for_result``.
        abort_event: Optional object with ``is_set()``; checked before each
            turn and between tool calls.
        files_root: Forwarded to ``execute_tool``.
        multimodal: When true, tool results carrying an image are sent to the
            LLM as text plus a down-scaled thumbnail instead of text only.
        tab_id: Used to build the figure name prefix ``figure_T{tab_id}_``.
        image_store: Persistent figure store to use; a fresh dict when None.
        image_counter: Counter value to continue figure numbering from.

    Yields dicts with SSE event types:
        - thinking: { content }
        - content: { content }
        - tool_start: { tool, args, tool_call_id, arguments, thinking }
        - tool_result: { tool, tool_call_id, result, response, image?, image_name? }
        - result_preview: { content, figures }
        - result: { content, figures }
        - generating: {}
        - retry: { attempt, max_attempts, delay, message }
        - error: { content }
        - aborted: {}
        - done: {}
    """
    from .agents import call_llm

    result_pattern = re.compile(r'<result>(.*?)</result>', re.DOTALL | re.IGNORECASE)

    turns = 0
    done = False
    image_prefix = f"figure_T{tab_id}_"
    # Use provided persistent store, or create a local one as fallback
    if image_store is None:
        image_store = {}
    result_sent = False
    debug_call_number = 0

    while not done and turns < MAX_TURNS:
        # Check abort before each turn
        if abort_event and abort_event.is_set():
            yield {"type": "aborted"}
            return
        turns += 1

        # LLM call with retries and debug events. call_llm yields ordinary
        # events to relay plus one internal event carrying the response.
        response = None
        for event in call_llm(client, model, messages, tools=TOOLS, extra_params=extra_params, abort_event=abort_event, call_number=debug_call_number):
            if "_response" in event:
                response = event["_response"]
                debug_call_number = event["_call_number"]
            else:
                yield event
                if event.get("type") in ("error", "aborted"):
                    return
        if response is None:
            return

        # --- Parse response ---
        assistant_message = response.choices[0].message
        content = assistant_message.content or ""
        tool_calls = assistant_message.tool_calls or []

        # Check for <result> tags; everything outside them counts as thinking
        result_match = result_pattern.search(content)
        result_content = None
        thinking_content = content
        if result_match:
            result_content = result_match.group(1).strip()
            thinking_content = result_pattern.sub('', content).strip()

        # Send thinking/content
        if thinking_content.strip():
            if tool_calls:
                yield {"type": "thinking", "content": thinking_content}
            else:
                yield {"type": "content", "content": thinking_content}

        # Send result preview
        if result_content:
            yield {"type": "result_preview", "content": result_content, "figures": dict(image_store)}

        # --- Handle tool calls ---
        if tool_calls:
            for tool_call in tool_calls:
                # Check abort between tool calls
                if abort_event and abort_event.is_set():
                    yield {"type": "aborted"}
                    return

                func_name = tool_call.function.name
                raw_arguments = tool_call.function.arguments
                # One assistant message is recorded per tool call, each
                # repeating the assistant text and carrying a single call.
                assistant_entry = {
                    "role": "assistant",
                    "content": content,
                    "tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": func_name, "arguments": raw_arguments}}],
                }

                # Parse arguments. Besides malformed JSON, reject missing
                # arguments (TypeError) and JSON that is not an object
                # ("null", a list, ...): execute_tool needs a dict and would
                # otherwise crash the stream on args.get().
                parse_error = None
                args = None
                try:
                    args = json.loads(raw_arguments)
                except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
                    parse_error = str(e)
                else:
                    if not isinstance(args, dict):
                        parse_error = f"expected a JSON object, got {type(args).__name__}"
                if parse_error is not None:
                    output = f"Error parsing arguments: {parse_error}"
                    messages.append(assistant_entry)
                    messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": output})
                    yield {"type": "error", "content": output}
                    continue

                # Signal tool start
                yield {
                    "type": "tool_start",
                    "tool": func_name,
                    "args": args,
                    "tool_call_id": tool_call.id,
                    "arguments": raw_arguments,
                    "thinking": content,
                }

                # Execute tool
                result = execute_tool(func_name, args, hf_token, image_store, image_counter, default_gen_model=image_gen_model, default_edit_model=image_edit_model, files_root=files_root, image_prefix=image_prefix)
                image_counter = result.get("image_counter", image_counter)

                # Build tool response content for LLM
                if result.get("image") and multimodal:
                    # NOTE(review): resize_image_for_vlm returns its input
                    # unchanged on failure, which would still be labelled
                    # image/jpeg here.
                    vlm_image = resize_image_for_vlm(result["image"])
                    tool_response_content = [
                        {"type": "text", "text": result["content"]},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{vlm_image}"}},
                    ]
                else:
                    tool_response_content = result["content"]

                # Add to message history
                messages.append(assistant_entry)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_response_content,
                })

                # Signal tool result to frontend (full-resolution image)
                tool_result_event = {
                    "type": "tool_result",
                    "tool": func_name,
                    "tool_call_id": tool_call.id,
                    "result": result.get("display", {}),
                    "response": result.get("content", ""),
                }
                if result.get("image"):
                    tool_result_event["image"] = result["image"]
                if result.get("image_name"):
                    tool_result_event["image_name"] = result["image_name"]
                yield tool_result_event
        else:
            # No tool calls — we're done
            messages.append({"role": "assistant", "content": content})
            done = True
            # Send result if found
            if result_content:
                yield {"type": "result", "content": result_content, "figures": dict(image_store)}
                result_sent = True

        # Signal between-turn processing
        if not done:
            yield {"type": "generating"}

    # If agent finished without a <result>, nudge it for one
    if not result_sent:
        from .agents import nudge_for_result

        nudge_produced_result = False
        for event in nudge_for_result(client, model, messages, extra_params=extra_params, extra_result_data={"figures": dict(image_store)}, call_number=debug_call_number):
            yield event
            if event.get("type") == "result":
                nudge_produced_result = True

        # Final fallback: synthesize a result with all figures
        if not nudge_produced_result:
            fallback_parts = [f"<{name}>" for name in image_store]
            yield {"type": "result", "content": "\n\n".join(fallback_parts), "figures": dict(image_store)}

    yield {"type": "done"}