#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The Footscray Coding Collective. All rights reserved.
"""
Zhou Protocol FLUX-LoRA Integration Tool

This module provides a Smolagents Tool implementation for interacting with
FLUX-LoRA-DLC API. It enables agents to generate high-quality images with
customizable LoRA models.

Usage:
    flux_tool = FluxLoRATool()
    agent = CodeAgent(tools=[flux_tool], ...)
"""

import logging
import os
import random
import tempfile
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional

# Third-party
import requests
from gradio_client import Client
from PIL import Image
from smolagents import Tool

# -----------------------------------------------------------------------------
# CONSTANTS AND TYPE DEFINITIONS
# -----------------------------------------------------------------------------

# Ordered (substring, extension) pairs used to map a Content-Type header to a
# file extension. The first substring found in the header wins.
_CONTENT_TYPE_EXTENSIONS = (
    ("jpeg", ".jpg"),
    ("jpg", ".jpg"),
    ("png", ".png"),
    ("webp", ".webp"),
    ("gif", ".gif"),
)

# Extension used when the Content-Type header is missing or unrecognised.
_DEFAULT_IMAGE_EXTENSION = ".png"

# Timeout (seconds) for downloading an input image over HTTP.
_DOWNLOAD_TIMEOUT_SECONDS = 30

# Chunk size (bytes) used when streaming a download to disk.
_DOWNLOAD_CHUNK_SIZE = 8192


def _extension_from_content_type(content_type: str) -> str:
    """
    Map an HTTP Content-Type header value to an image file extension.

    Args:
        content_type: Raw Content-Type header value (may be empty)

    Returns:
        File extension including the leading dot; ".png" when the content
        type is empty or not recognised
    """
    lowered = content_type.lower()
    for marker, extension in _CONTENT_TYPE_EXTENSIONS:
        if marker in lowered:
            return extension
    return _DEFAULT_IMAGE_EXTENSION


@dataclass
class LoRAModelInfo:
    """Value object representing LoRA model information."""

    name: str
    description: Optional[str] = None
    example_image_url: Optional[str] = None


@dataclass
class ImageGenerationResult:
    """Value object representing a generated image result."""

    image_path: str
    seed: int
    metadata: Optional[Dict[str, Any]] = None


# -----------------------------------------------------------------------------
# CORE TOOL IMPLEMENTATION
# -----------------------------------------------------------------------------


class FluxLoRATool(Tool):
    """
    Tool for generating images using FLUX-LoRA-DLC API.

    This tool implements the Zhou Protocol integration patterns to provide
    a clean, efficient interface for image generation using LoRA models.
    """

    name = "flux_lora_generator"
    description = """
    Generates high-quality images using FLUX-LoRA models.
    Can use custom LoRA models, adjust image parameters, and handle image inputs.
    """
    # Input schema consumed by smolagents. Floating-point inputs are declared
    # as "number" (smolagents has no "float" type) and every parameter that
    # defaults to None in forward() is marked "nullable". The pre-existing
    # "optional"/"default" keys are kept as informational metadata.
    inputs = {
        "prompt": {
            "type": "string",
            "description": "Detailed description of the desired image.",
        },
        "image_input": {
            "type": "string",
            "description": "Optional URL or file path to input image for img2img generation.",
            "optional": True,
            "nullable": True,
        },
        "image_strength": {
            "type": "number",
            "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
            "optional": True,
            "nullable": True,
            "default": 0.75,
        },
        "cfg_scale": {
            "type": "number",
            "description": "Guidance scale for prompt adherence (1.0-30.0).",
            "optional": True,
            "nullable": True,
            "default": 3.5,
        },
        "steps": {
            "type": "integer",
            "description": "Number of sampling steps (10-100).",
            "optional": True,
            "nullable": True,
            "default": 28,
        },
        "seed": {
            "type": "integer",
            "description": "Random seed for reproducibility. Use -1 for random seed.",
            "optional": True,
            "nullable": True,
            "default": -1,
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels.",
            "optional": True,
            "nullable": True,
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels.",
            "optional": True,
            "nullable": True,
            "default": 1024,
        },
        "lora_scale": {
            "type": "number",
            "description": "LoRA influence scale (0.0-1.0).",
            "optional": True,
            "nullable": True,
            "default": 0.95,
        },
        "custom_lora": {
            "type": "string",
            "description": "Custom LoRA model to use. Leave empty for default.",
            "optional": True,
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(
        self,
        api_url: str = "xkerser/FLUX-LoRA-DLC",
        image_save_dir: Optional[str] = None,
        connection_timeout: int = 60,
        verbose: bool = False,
    ):
        """
        Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.

        Args:
            api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
            image_save_dir: Directory to save generated images (created if doesn't exist)
            connection_timeout: API connection timeout in seconds
            verbose: Enable detailed logging
        """
        super().__init__()

        # Initialize logging
        self.logger = logging.getLogger("flux_lora_tool")
        self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)

        # Connection settings; the client itself is created on first use so
        # that constructing the tool performs no network I/O.
        self.api_url = api_url
        self.connection_timeout = connection_timeout
        self._client = None

        # Set up image storage directory
        self.image_save_dir = image_save_dir or os.path.join(
            tempfile.gettempdir(), "flux_lora_images"
        )
        os.makedirs(self.image_save_dir, exist_ok=True)

        self.logger.info(
            f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}"
        )

    @property
    def client(self) -> Client:
        """
        Get or initialize the Gradio client with proper connection handling.

        Returns:
            Initialized Gradio client

        Raises:
            ConnectionError: If client initialization fails
        """
        if self._client is None:
            try:
                # gradio_client.Client has no `timeout` keyword; HTTP options
                # such as the timeout are forwarded through `httpx_kwargs`.
                self._client = Client(
                    self.api_url,
                    httpx_kwargs={"timeout": self.connection_timeout},
                )
                self.logger.debug(f"Gradio client initialized for: {self.api_url}")
            except Exception as e:
                error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
                self.logger.error(error_msg)
                raise ConnectionError(error_msg) from e

        return self._client

    def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
        """
        Validate and normalize input parameters with Zhou Protocol validation patterns.

        Numeric parameters that are missing or None are replaced by their
        defaults. A missing seed, a None seed, and a seed of -1 all resolve
        to a random seed. A URL passed as ``image_input`` is downloaded to a
        local file as part of validation.

        Args:
            **kwargs: Input parameters

        Returns:
            Validated and normalized parameters

        Raises:
            ValueError: If input validation fails
            ConnectionError: If an ``image_input`` URL cannot be downloaded
        """
        validated: Dict[str, Any] = {}

        # Required parameter: prompt (must contain non-whitespace text)
        prompt = kwargs.get("prompt")
        if not isinstance(prompt, str) or not prompt.strip():
            raise ValueError("Prompt is required for image generation")
        validated["prompt"] = prompt

        # Image input handling
        raw_image_input = kwargs.get("image_input")
        if raw_image_input:
            # Accept path-like objects as well as plain strings.
            input_image = str(raw_image_input)

            if input_image.startswith(("http://", "https://")):
                # Remote image: fetch it so the API receives a local file.
                validated["image_input"] = self._download_image(input_image)
            else:
                # Local image: must be an existing regular file.
                if not os.path.isfile(input_image):
                    raise ValueError(f"Image file not found: {input_image}")
                validated["image_input"] = input_image

        # Numeric parameter validation with constraints
        numeric_params = {
            "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
            "cfg_scale": {"min": 1.0, "max": 30.0, "default": 3.5},
            "steps": {"min": 10, "max": 100, "default": 28},
            "width": {"min": 128, "max": 2048, "default": 1024},
            "height": {"min": 128, "max": 2048, "default": 1024},
            "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95},
        }
        integer_params = ("steps", "width", "height")

        for param, constraints in numeric_params.items():
            raw_value = kwargs.get(param)
            if raw_value is None:
                validated[param] = constraints["default"]
                continue

            # Type conversion if needed
            if param in integer_params:
                try:
                    value = int(raw_value)
                except (ValueError, TypeError) as exc:
                    raise ValueError(
                        f"Parameter '{param}' must be an integer"
                    ) from exc
            else:
                try:
                    value = float(raw_value)
                except (ValueError, TypeError) as exc:
                    raise ValueError(
                        f"Parameter '{param}' must be a number"
                    ) from exc

            # Range validation. Written as a chained comparison so that NaN,
            # which compares False against everything, is rejected too.
            if not (constraints["min"] <= value <= constraints["max"]):
                raise ValueError(
                    f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
                )

            validated[param] = value

        # Special handling for seed: None and -1 both mean "pick one for me".
        raw_seed = kwargs.get("seed")
        if raw_seed is None:
            seed = -1
        else:
            try:
                seed = int(raw_seed)
            except (ValueError, TypeError) as exc:
                raise ValueError("Seed must be an integer") from exc

        if seed == -1:
            seed = self._resolve_random_seed()
        validated["seed"] = seed

        # Custom LoRA handling
        if kwargs.get("custom_lora"):
            validated["custom_lora"] = kwargs["custom_lora"]

        return validated

    def _resolve_random_seed(self) -> int:
        """
        Obtain a random seed, preferring the API and falling back locally.

        Returns:
            Seed from the API's random-value endpoint, or a locally generated
            value in [0, 2**32 - 1] when the API call fails
        """
        try:
            return self._get_random_seed()
        except Exception as e:
            # Best-effort: seed generation must not block image generation.
            self.logger.warning(f"Failed to get random seed from API: {e}")
            return random.randint(0, 2**32 - 1)

    def _download_image(self, url: str) -> str:
        """
        Download image from URL and save to local file.

        Args:
            url: Image URL

        Returns:
            Local file path

        Raises:
            ConnectionError: If download fails
        """
        try:
            # The context manager releases the streamed connection even when
            # writing the file fails part-way through.
            with requests.get(
                url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
            ) as response:
                response.raise_for_status()

                # Generate a unique file path inside the tool's image directory
                file_ext = self._guess_extension(
                    response.headers.get("Content-Type", "")
                )
                temp_path = os.path.join(
                    self.image_save_dir, f"input_{uuid.uuid4().hex}{file_ext}"
                )

                # Save image
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(
                        chunk_size=_DOWNLOAD_CHUNK_SIZE
                    ):
                        f.write(chunk)

            self.logger.debug(f"Downloaded image from {url} to {temp_path}")
            return temp_path

        except Exception as e:
            error_msg = f"Failed to download image from {url}: {str(e)}"
            self.logger.error(error_msg)
            raise ConnectionError(error_msg) from e

    def _guess_extension(self, content_type: str) -> str:
        """
        Guess file extension from content type.

        Args:
            content_type: HTTP Content-Type header

        Returns:
            File extension (with dot); ".png" when the type is not recognised
        """
        return _extension_from_content_type(content_type)

    def _get_random_seed(self) -> int:
        """
        Get a random seed from the API.

        Returns:
            Random seed value

        Raises:
            ValueError: If the API returns a non-numeric result
            Exception: Any error raised by the client is logged and re-raised
        """
        try:
            result = self.client.predict(api_name="/get_random_value")
            if isinstance(result, (int, float)):
                return int(result)
            raise ValueError(f"Unexpected result type: {type(result)}")
        except Exception as e:
            # Log and re-raise; callers that need a seed regardless should go
            # through _resolve_random_seed, which provides the local fallback.
            self.logger.warning(f"Failed to get random seed: {e}")
            raise

    def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
        """
        Add or remove custom LoRA model.

        A non-empty value is registered through the "/add_custom_lora"
        endpoint; an empty value or None calls "/remove_custom_lora".

        Args:
            custom_lora: Custom LoRA model string

        Raises:
            RuntimeError: If LoRA handling fails
        """
        if custom_lora:
            # Add custom LoRA
            try:
                self.client.predict(
                    custom_lora=custom_lora, api_name="/add_custom_lora"
                )
                self.logger.debug(f"Added custom LoRA: {custom_lora}")
            except Exception as e:
                error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
                self.logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            return

        # Remove any existing custom LoRA
        try:
            self.client.predict(api_name="/remove_custom_lora")
            self.logger.debug("Removed custom LoRA")
        except Exception as e:
            error_msg = f"Failed to remove custom LoRA: {str(e)}"
            self.logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def forward(
        self,
        prompt: str,
        image_input: Optional[str] = None,
        image_strength: Optional[float] = None,
        cfg_scale: Optional[float] = None,
        steps: Optional[int] = None,
        seed: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        lora_scale: Optional[float] = None,
        custom_lora: Optional[str] = None,
    ) -> str:
        """
        Generate an image with FLUX-LoRA.

        Validation, custom LoRA setup and generation failures are reported in
        the returned string rather than raised, so the calling agent always
        receives a textual result for those cases.

        Args:
            prompt: Text description of the desired image
            image_input: Optional path or URL to input image for img2img
            image_strength: Strength of input image influence (0.0-1.0)
            cfg_scale: Guidance scale (1.0-30.0)
            steps: Number of sampling steps (10-100)
            seed: Random seed (-1 for random)
            width: Image width in pixels (128-2048)
            height: Image height in pixels (128-2048)
            lora_scale: LoRA influence scale (0.0-1.0)
            custom_lora: Custom LoRA model to use

        Returns:
            Formatted string with image generation results, or a message describing the failure

        Raises:
            ConnectionError: If an image_input URL cannot be downloaded
        """
        # Step 1: Validate and normalize inputs
        try:
            params = self._validate_inputs(
                prompt=prompt,
                image_input=image_input,
                image_strength=image_strength,
                cfg_scale=cfg_scale,
                steps=steps,
                seed=seed,
                width=width,
                height=height,
                lora_scale=lora_scale,
                custom_lora=custom_lora,
            )
        except ValueError as e:
            return f"Parameter validation failed: {str(e)}"
        self.logger.debug(f"Validated parameters: {params}")

        # Step 2: Handle custom LoRA if specified
        # NOTE(review): when no custom LoRA is given, nothing is removed, so a
        # LoRA added by an earlier call may still be active on the client
        # session - confirm whether removal should happen here.
        custom_lora_value = params.pop("custom_lora", None)
        if custom_lora_value:
            try:
                self._handle_custom_lora(custom_lora_value)
            except RuntimeError as e:
                return f"Custom LoRA setup failed: {str(e)}"

        # Step 3: Generate image
        try:
            generation_args = {
                "prompt": params["prompt"],
                "image_strength": params["image_strength"],
                "cfg_scale": params["cfg_scale"],
                "steps": params["steps"],
                "randomize_seed": False,  # We handle seed explicitly
                "seed": params["seed"],
                "width": params["width"],
                "height": params["height"],
                "lora_scale": params["lora_scale"],
            }

            # Add image input if available
            local_image = params.get("image_input")
            if local_image:
                # Imported lazily so an import failure is reported through
                # the surrounding error handling instead of at module load.
                from gradio_client import handle_file

                generation_args["image_input"] = handle_file(local_image)

            self.logger.info(f"Generating image with params: {generation_args}")
            result = self.client.predict(api_name="/run_lora", **generation_args)

            # Process result: expected shape is (image_path, seed, ...)
            if not (isinstance(result, (tuple, list)) and len(result) >= 2):
                raise ValueError(f"Unexpected API response format: {result}")
            image_path, actual_seed = result[0], result[1]

            # Save image to our directory
            try:
                output_path = self._save_image(image_path)
                image_result = ImageGenerationResult(
                    image_path=output_path, seed=int(actual_seed)
                )
                return self._format_result(image_result, params["prompt"])
            except Exception as e:
                self.logger.error(f"Failed to save generated image: {e}")
                return f"Image generated but failed to save: {str(e)}"

        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            self.logger.error(error_msg)
            return error_msg

    def _save_image(self, image_path: str) -> str:
        """
        Save generated image to specified directory.

        The image is re-encoded as PNG under a unique file name inside
        ``self.image_save_dir``.

        Args:
            image_path: Path to generated image from API

        Returns:
            Path to saved image

        Raises:
            IOError: If image saving fails
        """
        try:
            # Short random suffix keeps file names unique across generations
            unique_id = uuid.uuid4().hex[:8]
            output_filename = f"flux_lora_{unique_id}.png"
            output_path = os.path.join(self.image_save_dir, output_filename)

            # The context manager closes the source file handle after saving
            with Image.open(image_path) as img:
                img.save(output_path)

            self.logger.debug(f"Saved image to {output_path}")
            return output_path

        except Exception as e:
            error_msg = f"Failed to save image: {str(e)}"
            self.logger.error(error_msg)
            raise IOError(error_msg) from e

    def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
        """
        Format the image generation result as a string.

        Args:
            result: Image generation result
            prompt: Original prompt

        Returns:
            Formatted string with generation details
        """
        lines = [
            "📷 Image generated successfully!",
            f"🖼️ Image saved to: {result.image_path}",
            f"🌱 Seed used: {result.seed}",
            f"📝 Original prompt: {prompt}",
        ]

        # Add metadata if available
        if result.metadata:
            lines.append("📊 Additional metadata:")
            lines.extend(
                f" - {key}: {value}" for key, value in result.metadata.items()
            )

        return "\n".join(lines)


# -----------------------------------------------------------------------------
# UTILITY FUNCTIONS
# -----------------------------------------------------------------------------


def download_image(url: str, output_dir: Optional[str] = None) -> str:
    """
    Standalone utility to download an image from a URL.

    Args:
        url: Image URL
        output_dir: Directory to save image (created if doesn't exist)

    Returns:
        Path to downloaded image

    Raises:
        ValueError: If URL is invalid
        ConnectionError: If download fails
        IOError: If saving fails
    """
    if not url.startswith(("http://", "https://")):
        raise ValueError(f"Invalid URL: {url}")

    # Setup output directory
    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
    os.makedirs(output_dir, exist_ok=True)

    try:
        # The context manager releases the streamed connection on every path
        with requests.get(
            url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
        ) as response:
            response.raise_for_status()

            # Determine file extension with the same mapping the tool uses
            ext = _extension_from_content_type(
                response.headers.get("Content-Type", "")
            )

            # Save image
            output_path = os.path.join(
                output_dir, f"download_{uuid.uuid4().hex}{ext}"
            )
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                    f.write(chunk)

        return output_path

    # RequestException must be handled first: it is itself an IOError subclass
    except requests.RequestException as e:
        raise ConnectionError(f"Failed to download image: {str(e)}") from e
    except IOError as e:
        raise IOError(f"Failed to save image: {str(e)}") from e