""" Image agent backend — multimodal agent with HuggingFace image generation tools. Uses the same tool-calling loop pattern as agent.py: LLM call → parse tool_calls → execute → update history → repeat Key difference: maintains a figure store (Dict[str, str]) mapping names like "figure_T1_1" to base64 data, so the VLM can reference images across tool calls without passing huge base64 strings in arguments. """ import base64 import json import logging import re from typing import List, Dict, Optional from .tools import ( generate_image, edit_image, read_image, save_image, execute_generate_image, execute_edit_image, execute_read_image, ) logger = logging.getLogger(__name__) TOOLS = [generate_image, edit_image, read_image, save_image] # Max dimension for images sent to the VLM context (keeps token count manageable) VLM_IMAGE_MAX_DIM = 512 VLM_IMAGE_JPEG_QUALITY = 70 def resize_image_for_vlm(base64_png: str) -> str: """Resize and compress an image for VLM context to avoid token overflow. Takes a full-res base64 PNG and returns a smaller base64 JPEG thumbnail that fits within VLM_IMAGE_MAX_DIM on its longest side. """ try: from PIL import Image import io as _io img_bytes = base64.b64decode(base64_png) img = Image.open(_io.BytesIO(img_bytes)) # Resize if larger than max dimension if max(img.size) > VLM_IMAGE_MAX_DIM: img.thumbnail((VLM_IMAGE_MAX_DIM, VLM_IMAGE_MAX_DIM), Image.LANCZOS) # Convert to RGB (JPEG doesn't support alpha) if img.mode in ("RGBA", "P"): img = img.convert("RGB") # Save as JPEG for much smaller base64 buffer = _io.BytesIO() img.save(buffer, format="JPEG", quality=VLM_IMAGE_JPEG_QUALITY) return base64.b64encode(buffer.getvalue()).decode("utf-8") except Exception as e: logger.error(f"Failed to resize image for VLM: {e}") # Fall back to original — better to try than to lose the image entirely return base64_png MAX_TURNS = 20 def execute_tool(tool_name: str, args: dict, hf_token: str, image_store: dict, image_counter: int, default_gen_model: str = None, default_edit_model: str = None, files_root: str = None, image_prefix: str = "figure_") -> dict: """ Execute a tool by name and return result dict. Returns: dict with keys: - "content": str result for the LLM - "image": optional base64 PNG - "image_name": optional image reference name (e.g., "image_1") - "display": dict with display-friendly data for frontend - "image_counter": updated counter """ if tool_name == "generate_image": prompt = args.get("prompt", "") model = args.get("model") or default_gen_model or "black-forest-labs/FLUX.1-schnell" base64_png, error = execute_generate_image(prompt, hf_token, model) if base64_png: image_counter += 1 name = f"{image_prefix}{image_counter}" image_store[name] = {"type": "png", "data": base64_png} return { "content": f"Image generated successfully as '{name}'. The image is attached.", "image": base64_png, "image_name": name, "display": {"type": "generate", "prompt": prompt, "model": model, "image_name": name}, "image_counter": image_counter, } else: return { "content": f"Failed to generate image: {error}", "display": {"type": "generate_error", "prompt": prompt}, "image_counter": image_counter, } elif tool_name == "edit_image": prompt = args.get("prompt", "") source = args.get("source", "") model = args.get("model") or default_edit_model or "black-forest-labs/FLUX.1-Kontext-dev" # Resolve source: image store reference, URL, or local path source_bytes = None if source in image_store: source_bytes = base64.b64decode(image_store[source]["data"]) else: source_base64 = execute_read_image(source, files_root=files_root) if source_base64: source_bytes = base64.b64decode(source_base64) if source_bytes is None: return { "content": f"Could not resolve image source '{source}'. Use a URL or a reference from a previous tool call (e.g., 'figure_T1_1').", "display": {"type": "edit_error", "source": source}, "image_counter": image_counter, } base64_png, error = execute_edit_image(prompt, source_bytes, hf_token, model) if base64_png: image_counter += 1 name = f"{image_prefix}{image_counter}" image_store[name] = {"type": "png", "data": base64_png} return { "content": f"Image edited successfully as '{name}'. The image is attached.", "image": base64_png, "image_name": name, "display": {"type": "edit", "prompt": prompt, "source": source, "model": model, "image_name": name}, "image_counter": image_counter, } else: return { "content": f"Failed to edit image: {error}", "display": {"type": "edit_error", "source": source}, "image_counter": image_counter, } elif tool_name == "save_image": source = args.get("source", "") filename = args.get("filename", "image.png") # Ensure .png extension if not filename.lower().endswith(".png"): filename += ".png" # Resolve source from image store or URL image_data = None if source in image_store: image_data = base64.b64decode(image_store[source]["data"]) else: source_base64 = execute_read_image(source, files_root=files_root) if source_base64: image_data = base64.b64decode(source_base64) if image_data is None: return { "content": f"Could not resolve image source '{source}'. Use a reference (e.g., 'figure_T1_1') or a URL.", "display": {"type": "save_error", "source": source}, "image_counter": image_counter, } # Save to files_root import os save_dir = files_root or "." os.makedirs(save_dir, exist_ok=True) # Sanitize filename filename = os.path.basename(filename) save_path = os.path.join(save_dir, filename) with open(save_path, "wb") as f: f.write(image_data) # Include base64 so frontend can show a preview of the saved image saved_base64 = base64.b64encode(image_data).decode("utf-8") return { "content": f"Image saved as '{filename}'.", "image": saved_base64, "display": {"type": "save_image", "filename": filename, "source": source}, "image_counter": image_counter, } elif tool_name in ("read_image", "read_image_url"): source = args.get("source") or args.get("url", "") base64_png = execute_read_image(source, files_root=files_root) if base64_png: image_counter += 1 name = f"{image_prefix}{image_counter}" image_store[name] = {"type": "png", "data": base64_png} return { "content": f"Image loaded successfully as '{name}'. The image is attached.", "image": base64_png, "image_name": name, "display": {"type": "read_image", "url": source, "image_name": name}, "image_counter": image_counter, } else: # Provide more specific error for SVG files is_svg = source.lower().endswith(".svg") or "/svg" in source.lower() if is_svg: error_msg = f"Failed to load image from '{source}'. SVG format is not supported — only raster formats (PNG, JPEG, GIF, WebP, BMP) are accepted. Ask the user for a raster version of the image." else: error_msg = f"Failed to load image from '{source}'. Check that the path or URL is correct and that it is a raster image (PNG, JPEG, GIF, WebP, BMP)." return { "content": error_msg, "display": {"type": "read_image_error", "url": source}, "image_counter": image_counter, } return { "content": f"Unknown tool: {tool_name}", "display": {"type": "error"}, "image_counter": image_counter, } def stream_image_execution( client, model: str, messages: List[Dict], hf_token: str, image_gen_model: Optional[str] = None, image_edit_model: Optional[str] = None, extra_params: Optional[Dict] = None, abort_event=None, files_root: str = None, multimodal: bool = False, tab_id: str = "0", image_store: Optional[Dict[str, dict]] = None, image_counter: int = 0, ): """ Run the image agent tool-calling loop. Yields dicts with SSE event types: - thinking: { content } - content: { content } - tool_start: { tool, args } - tool_result: { tool, result, image? } - result_preview: { content } - result: { content, figures? } - generating: {} - retry: { attempt, max_attempts, delay, message } - error: { content } - done: {} """ from .agents import call_llm turns = 0 done = False image_prefix = f"figure_T{tab_id}_" # Use provided persistent store, or create a local one as fallback if image_store is None: image_store = {} result_sent = False debug_call_number = 0 while not done and turns < MAX_TURNS: # Check abort before each turn if abort_event and abort_event.is_set(): yield {"type": "aborted"} return turns += 1 # LLM call with retries and debug events response = None for event in call_llm(client, model, messages, tools=TOOLS, extra_params=extra_params, abort_event=abort_event, call_number=debug_call_number): if "_response" in event: response = event["_response"] debug_call_number = event["_call_number"] else: yield event if event.get("type") in ("error", "aborted"): return if response is None: return # --- Parse response --- assistant_message = response.choices[0].message content = assistant_message.content or "" tool_calls = assistant_message.tool_calls or [] # Check for tags result_match = re.search(r'(.*?)', content, re.DOTALL | re.IGNORECASE) result_content = None thinking_content = content if result_match: result_content = result_match.group(1).strip() thinking_content = re.sub(r'.*?', '', content, flags=re.DOTALL | re.IGNORECASE).strip() # Send thinking/content if thinking_content.strip(): if tool_calls: yield {"type": "thinking", "content": thinking_content} else: yield {"type": "content", "content": thinking_content} # Send result preview if result_content: figures = dict(image_store) yield {"type": "result_preview", "content": result_content, "figures": figures} # --- Handle tool calls --- if tool_calls: for tool_call in tool_calls: # Check abort between tool calls if abort_event and abort_event.is_set(): yield {"type": "aborted"} return func_name = tool_call.function.name # Parse arguments try: args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: output = f"Error parsing arguments: {e}" messages.append({ "role": "assistant", "content": content, "tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": func_name, "arguments": tool_call.function.arguments}}] }) messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": output}) yield {"type": "error", "content": output} continue # Signal tool start yield { "type": "tool_start", "tool": func_name, "args": args, "tool_call_id": tool_call.id, "arguments": tool_call.function.arguments, "thinking": content, } # Execute tool result = execute_tool(func_name, args, hf_token, image_store, image_counter, default_gen_model=image_gen_model, default_edit_model=image_edit_model, files_root=files_root, image_prefix=image_prefix) image_counter = result.get("image_counter", image_counter) # Build tool response content for LLM if result.get("image") and multimodal: vlm_image = resize_image_for_vlm(result["image"]) tool_response_content = [ {"type": "text", "text": result["content"]}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{vlm_image}"}} ] else: tool_response_content = result["content"] # Add to message history messages.append({ "role": "assistant", "content": content, "tool_calls": [{"id": tool_call.id, "type": "function", "function": {"name": func_name, "arguments": tool_call.function.arguments}}] }) messages.append({ "role": "tool", "tool_call_id": tool_call.id, "content": tool_response_content }) # Signal tool result to frontend tool_result_event = { "type": "tool_result", "tool": func_name, "tool_call_id": tool_call.id, "result": result.get("display", {}), "response": result.get("content", ""), } if result.get("image"): tool_result_event["image"] = result["image"] if result.get("image_name"): tool_result_event["image_name"] = result["image_name"] yield tool_result_event else: # No tool calls — we're done messages.append({"role": "assistant", "content": content}) done = True # Send result if found if result_content: figures = dict(image_store) yield {"type": "result", "content": result_content, "figures": figures} result_sent = True # Signal between-turn processing if not done: yield {"type": "generating"} # If agent finished without a , nudge it for one if not result_sent: from .agents import nudge_for_result nudge_produced_result = False figures = dict(image_store) for event in nudge_for_result(client, model, messages, extra_params=extra_params, extra_result_data={"figures": figures}, call_number=debug_call_number): yield event if event.get("type") == "result": nudge_produced_result = True # Final fallback: synthesize a result with all figures if not nudge_produced_result: fallback_parts = [f"<{name}>" for name in image_store] figures = dict(image_store) yield {"type": "result", "content": "\n\n".join(fallback_parts), "figures": figures} yield {"type": "done"}