Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| # Copyright 2024 The Footscray Coding Collective. All rights reserved. | |
| """ | |
| Zhou Protocol FLUX-LoRA Integration Tool | |
| This module provides a Smolagents Tool implementation for interacting with FLUX-LoRA-DLC API. | |
| It enables agents to generate high-quality images with customizable LoRA models. | |
| Usage: | |
| flux_tool = FluxLoRATool() | |
| agent = CodeAgent(tools=[flux_tool], ...) | |
| """ | |
| import logging | |
| import os | |
| import tempfile | |
| import uuid | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional | |
| # Third-party | |
| import requests | |
| from gradio_client import Client | |
| from PIL import Image | |
| from smolagents import Tool | |
| # ----------------------------------------------------------------------------- | |
| # CONSTANTS AND TYPE DEFINITIONS | |
| # ----------------------------------------------------------------------------- | |
@dataclass
class LoRAModelInfo:
    """Value object representing LoRA model information.

    Declared as a dataclass so it can be constructed with positional or
    keyword arguments (``LoRAModelInfo(name="...")``) and compared by value.

    Attributes:
        name: Name of the LoRA model.
        description: Optional human-readable description of the model.
        example_image_url: Optional URL of an example image for the model.
    """

    name: str
    description: Optional[str] = None
    example_image_url: Optional[str] = None
@dataclass
class ImageGenerationResult:
    """Value object representing a generated image result.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``ImageGenerationResult(image_path=..., seed=...)``; without the
    decorator that call raises ``TypeError``.

    Attributes:
        image_path: Local filesystem path of the saved image.
        seed: Seed value associated with the generation.
        metadata: Optional extra key/value details about the generation.
    """

    image_path: str
    seed: int
    metadata: Optional[Dict[str, Any]] = None
| # ----------------------------------------------------------------------------- | |
| # CORE TOOL IMPLEMENTATION | |
| # ----------------------------------------------------------------------------- | |
class FluxLoRATool(Tool):
    """
    Tool for generating images using FLUX-LoRA-DLC API.

    This tool implements the Zhou Protocol integration patterns to provide
    a clean, efficient interface for image generation using LoRA models.
    Requests go to a Gradio app through ``gradio_client``; each generated
    image is copied into ``image_save_dir`` and ``forward`` returns a
    human-readable summary string. Failures are also reported as strings
    rather than raised.
    """

    name = "flux_lora_generator"
    description = """
    Generates high-quality images using FLUX-LoRA models.
    Can use custom LoRA models, adjust image parameters, and handle image inputs.
    """
    # smolagents validates ``inputs`` at construction time: types must be its
    # JSON-schema names ("number", not "float"), and arguments typed Optional
    # in ``forward`` must be marked ``"nullable": True``.
    inputs = {
        "prompt": {
            "type": "string",
            "description": "Detailed description of the desired image.",
        },
        "image_input": {
            "type": "string",
            "description": "Optional URL or file path to input image for img2img generation.",
            "nullable": True,
        },
        # NOTE(review): if the Space forwards this as a diffusers img2img
        # ``strength``, higher values keep LESS of the original image; the
        # description below may be inverted — confirm against the Space.
        "image_strength": {
            "type": "number",
            "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
            "nullable": True,
            "default": 0.75,
        },
        "cfg_scale": {
            "type": "number",
            "description": "Guidance scale for prompt adherence (1.0-30.0).",
            "nullable": True,
            "default": 3.5,
        },
        "steps": {
            "type": "integer",
            "description": "Number of sampling steps (10-100).",
            "nullable": True,
            "default": 28,
        },
        "seed": {
            "type": "integer",
            "description": "Random seed for reproducibility. Use -1 for random seed.",
            "nullable": True,
            "default": -1,
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels.",
            "nullable": True,
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels.",
            "nullable": True,
            "default": 1024,
        },
        "lora_scale": {
            "type": "number",
            "description": "LoRA influence scale (0.0-1.0).",
            "nullable": True,
            "default": 0.95,
        },
        "custom_lora": {
            "type": "string",
            "description": "Custom LoRA model to use. Leave empty for default.",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(
        self,
        api_url: str = "xkerser/FLUX-LoRA-DLC",
        image_save_dir: Optional[str] = None,
        connection_timeout: int = 60,
        verbose: bool = False,
    ):
        """
        Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.

        Args:
            api_url: URL or endpoint ID for the FLUX-LoRA-DLC API.
            image_save_dir: Directory to save generated images (created if it
                doesn't exist). Defaults to ``<temp dir>/flux_lora_images``.
            connection_timeout: HTTP timeout in seconds for API requests.
            verbose: Enable DEBUG-level logging.
        """
        super().__init__()
        self.logger = logging.getLogger("flux_lora_tool")
        self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)

        self.api_url = api_url
        self.connection_timeout = connection_timeout
        self._client: Optional[Client] = None  # created lazily by `client`
        # Custom LoRA currently loaded in the remote session (None = default),
        # so it can be removed when a later call asks for the default again.
        self._active_custom_lora: Optional[str] = None

        self.image_save_dir = image_save_dir or os.path.join(
            tempfile.gettempdir(), "flux_lora_images"
        )
        os.makedirs(self.image_save_dir, exist_ok=True)
        self.logger.info(
            f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}"
        )

    @property
    def client(self) -> Client:
        """
        Get or lazily initialize the Gradio client.

        Returns:
            Initialized Gradio client (cached after the first call).

        Raises:
            ConnectionError: If client initialization fails.
        """
        if self._client is None:
            try:
                # gradio_client.Client has no ``timeout`` kwarg; timeouts are
                # passed to the underlying httpx calls via ``httpx_kwargs``.
                self._client = Client(
                    self.api_url,
                    httpx_kwargs={"timeout": self.connection_timeout},
                )
                self.logger.debug(f"Gradio client initialized for: {self.api_url}")
            except Exception as e:
                error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
                self.logger.error(error_msg)
                raise ConnectionError(error_msg) from e
        return self._client

    def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
        """
        Validate and normalize input parameters.

        Missing (None) numeric parameters get their defaults; a seed of None
        or -1 is replaced with a random seed.

        Args:
            **kwargs: Raw parameters as passed to ``forward``.

        Returns:
            Validated parameters. ``image_input`` (a local path; URLs are
            downloaded first) and ``custom_lora`` are present only when given.

        Raises:
            ValueError: If a parameter is missing, malformed or out of range.
            ConnectionError: If an image URL cannot be downloaded.
        """
        validated: Dict[str, Any] = {}

        prompt = kwargs.get("prompt")
        if not prompt:
            raise ValueError("Prompt is required for image generation")
        validated["prompt"] = prompt

        input_image = kwargs.get("image_input")
        if input_image:
            if input_image.startswith(("http://", "https://")):
                validated["image_input"] = self._download_image(input_image)
            elif not os.path.exists(input_image):
                raise ValueError(f"Image file not found: {input_image}")
            else:
                validated["image_input"] = input_image

        numeric_params = {
            "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
            "cfg_scale": {"min": 1.0, "max": 30.0, "default": 3.5},
            "steps": {"min": 10, "max": 100, "default": 28},
            "width": {"min": 128, "max": 2048, "default": 1024},
            "height": {"min": 128, "max": 2048, "default": 1024},
            "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95},
        }
        integer_params = {"steps", "width", "height"}
        for param, constraints in numeric_params.items():
            raw_value = kwargs.get(param)
            if raw_value is None:
                validated[param] = constraints["default"]
                continue
            if param in integer_params:
                try:
                    value = int(raw_value)
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Parameter '{param}' must be an integer") from e
            else:
                try:
                    value = float(raw_value)
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Parameter '{param}' must be a number") from e
            # Chained comparison also rejects NaN, which fails every comparison.
            if not constraints["min"] <= value <= constraints["max"]:
                raise ValueError(
                    f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
                )
            validated[param] = value

        raw_seed = kwargs.get("seed")
        if raw_seed is None:
            seed = -1  # no seed given: treat like an explicit "random" request
        else:
            try:
                seed = int(raw_seed)
            except (ValueError, TypeError) as e:
                raise ValueError("Seed must be an integer") from e
        validated["seed"] = self._random_seed_with_fallback() if seed == -1 else seed

        custom_lora = kwargs.get("custom_lora")
        if custom_lora:
            validated["custom_lora"] = custom_lora

        return validated

    def _random_seed_with_fallback(self) -> int:
        """
        Return a random seed, preferring the API and falling back locally.

        Any failure of ``_get_random_seed`` (connection error, unexpected
        result type, ...) is logged and replaced by a local random value in
        ``[0, 2**32 - 1]``.
        """
        try:
            return self._get_random_seed()
        except Exception as e:
            self.logger.warning(f"Failed to get random seed from API: {e}")
            import random

            return random.randint(0, 2**32 - 1)

    def _download_image(self, url: str) -> str:
        """
        Download image from URL and save it into ``image_save_dir``.

        Args:
            url: Image URL.

        Returns:
            Local file path of the downloaded image.

        Raises:
            ConnectionError: If the download or the write fails. Any
                partially written file is removed first.
        """
        temp_path: Optional[str] = None
        try:
            # ``with`` closes the response, releasing the connection that
            # stream=True keeps open until the body is fully consumed.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
                temp_path = os.path.join(
                    self.image_save_dir, f"input_{uuid.uuid4().hex}{file_ext}"
                )
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            self.logger.debug(f"Downloaded image from {url} to {temp_path}")
            return temp_path
        except Exception as e:
            self._discard_file(temp_path)
            error_msg = f"Failed to download image from {url}: {str(e)}"
            self.logger.error(error_msg)
            raise ConnectionError(error_msg) from e

    @staticmethod
    def _discard_file(path: Optional[str]) -> None:
        """Best-effort removal of a partially written file; never raises."""
        if path and os.path.exists(path):
            try:
                os.remove(path)
            except OSError:
                # Cleanup failure must not mask the error being reported.
                pass

    def _guess_extension(self, content_type: str) -> str:
        """
        Guess file extension from an HTTP Content-Type header.

        Args:
            content_type: HTTP Content-Type header value.

        Returns:
            File extension including the dot; ``.png`` when unrecognized.
        """
        content_type = content_type.lower()
        if "jpeg" in content_type or "jpg" in content_type:
            return ".jpg"
        if "png" in content_type:
            return ".png"
        if "webp" in content_type:
            return ".webp"
        if "gif" in content_type:
            return ".gif"
        return ".png"  # Default to PNG

    def _get_random_seed(self) -> int:
        """
        Get a random seed from the API's ``/get_random_value`` endpoint.

        Returns:
            Random seed value.

        Raises:
            ConnectionError: If the client cannot be initialized.
            ValueError: If the API returns a non-numeric value.
            Exception: Any error raised by the API call is re-raised.
        """
        try:
            result = self.client.predict(api_name="/get_random_value")
            if isinstance(result, (int, float)):
                return int(result)
            raise ValueError(f"Unexpected result type: {type(result)}")
        except Exception as e:
            # Log and re-raise; callers decide whether to fall back.
            self.logger.warning(f"Failed to get random seed: {e}")
            raise

    def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
        """
        Add a custom LoRA model, or remove the current one when empty.

        On success, records the active custom LoRA (None after removal).

        Args:
            custom_lora: Custom LoRA model string; falsy means "remove".

        Raises:
            RuntimeError: If the add/remove API call fails.
        """
        if not custom_lora:
            try:
                self.client.predict(api_name="/remove_custom_lora")
                self.logger.debug("Removed custom LoRA")
            except Exception as e:
                error_msg = f"Failed to remove custom LoRA: {str(e)}"
                self.logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            self._active_custom_lora = None
        else:
            try:
                self.client.predict(
                    custom_lora=custom_lora, api_name="/add_custom_lora"
                )
                self.logger.debug(f"Added custom LoRA: {custom_lora}")
            except Exception as e:
                error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
                self.logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            self._active_custom_lora = custom_lora

    def forward(
        self,
        prompt: str,
        image_input: Optional[str] = None,
        image_strength: Optional[float] = None,
        cfg_scale: Optional[float] = None,
        steps: Optional[int] = None,
        seed: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        lora_scale: Optional[float] = None,
        custom_lora: Optional[str] = None,
    ) -> str:
        """
        Generate an image with FLUX-LoRA.

        Args:
            prompt: Text description of the desired image.
            image_input: Optional path or URL to input image for img2img.
            image_strength: Strength of input image influence (0.0-1.0).
            cfg_scale: Guidance scale (1.0-30.0).
            steps: Number of sampling steps (10-100).
            seed: Random seed (-1 or None for random).
            width: Image width in pixels (128-2048).
            height: Image height in pixels (128-2048).
            lora_scale: LoRA influence scale (0.0-1.0).
            custom_lora: Custom LoRA model to use; when omitted, any custom
                LoRA loaded by an earlier call is removed.

        Returns:
            Formatted string with image generation results, or a string
            describing the failure (validation, LoRA setup, generation or
            saving) — errors are reported, not raised.
        """
        # Step 1: Validate and normalize inputs
        try:
            params = self._validate_inputs(
                prompt=prompt,
                image_input=image_input,
                image_strength=image_strength,
                cfg_scale=cfg_scale,
                steps=steps,
                seed=seed,
                width=width,
                height=height,
                lora_scale=lora_scale,
                custom_lora=custom_lora,
            )
            self.logger.debug(f"Validated parameters: {params}")
        except (ValueError, ConnectionError) as e:
            # ConnectionError comes from downloading a URL image input.
            return f"Parameter validation failed: {str(e)}"

        # Step 2: Apply the requested LoRA, or restore the default if a custom
        # one is still loaded from a previous call.
        requested_lora = params.pop("custom_lora", None)
        if requested_lora or self._active_custom_lora:
            try:
                self._handle_custom_lora(requested_lora)
            except RuntimeError as e:
                return f"Custom LoRA setup failed: {str(e)}"

        # Step 3: Generate image
        try:
            img_param = None
            image_source = params.pop("image_input", None)
            if image_source:
                from gradio_client import handle_file

                img_param = handle_file(image_source)

            generation_args = {
                "prompt": params["prompt"],
                "image_strength": params["image_strength"],
                "cfg_scale": params["cfg_scale"],
                "steps": params["steps"],
                "randomize_seed": False,  # We handle seed explicitly
                "seed": params["seed"],
                "width": params["width"],
                "height": params["height"],
                "lora_scale": params["lora_scale"],
            }
            if img_param is not None:
                generation_args["image_input"] = img_param

            self.logger.info(f"Generating image with params: {generation_args}")
            result = self.client.predict(api_name="/run_lora", **generation_args)

            # Multiple outputs come back as a sequence: (image, seed, ...).
            if not isinstance(result, (tuple, list)) or len(result) < 2:
                raise ValueError(f"Unexpected API response format: {result}")
            image_path, actual_seed = result[0], result[1]
        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            self.logger.error(error_msg)
            return error_msg

        # Step 4: Copy the image into our directory and report
        try:
            output_path = self._save_image(image_path)
            image_result = ImageGenerationResult(
                image_path=output_path, seed=int(actual_seed)
            )
            return self._format_result(image_result, params["prompt"])
        except Exception as e:
            self.logger.error(f"Failed to save generated image: {e}")
            return f"Image generated but failed to save: {str(e)}"

    def _save_image(self, image_path: str) -> str:
        """
        Copy a generated image into ``image_save_dir`` as PNG.

        Args:
            image_path: Path to the generated image returned by the API.

        Returns:
            Path to the saved image (``flux_lora_<8 random hex chars>.png``).

        Raises:
            IOError: If the image cannot be opened or saved.
        """
        try:
            # Random (uuid-based) suffix keeps successive outputs distinct.
            output_filename = f"flux_lora_{uuid.uuid4().hex[:8]}.png"
            output_path = os.path.join(self.image_save_dir, output_filename)
            # ``with`` closes the source file handle PIL keeps open lazily.
            with Image.open(image_path) as img:
                img.save(output_path)
            self.logger.debug(f"Saved image to {output_path}")
            return output_path
        except Exception as e:
            error_msg = f"Failed to save image: {str(e)}"
            self.logger.error(error_msg)
            raise IOError(error_msg) from e

    def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
        """
        Format the image generation result as a multi-line string.

        Args:
            result: Image generation result.
            prompt: Original prompt.

        Returns:
            Newline-joined summary, plus one line per metadata entry if any.
        """
        lines = [
            "📷 Image generated successfully!",
            f"🖼️ Image saved to: {result.image_path}",
            f"🌱 Seed used: {result.seed}",
            f"📝 Original prompt: {prompt}",
        ]
        if result.metadata:
            lines.append("📊 Additional metadata:")
            lines.extend(f"  - {key}: {value}" for key, value in result.metadata.items())
        return "\n".join(lines)
| # ----------------------------------------------------------------------------- | |
| # UTILITY FUNCTIONS | |
| # ----------------------------------------------------------------------------- | |
def download_image(url: str, output_dir: Optional[str] = None) -> str:
    """
    Standalone utility to download an image from a URL.

    The file extension is chosen from the response's Content-Type header
    (JPEG, PNG, WebP or GIF; anything else falls back to ``.png``). If the
    download or the write fails part-way, the partially written file is
    removed before the error is raised.

    Args:
        url: Image URL; must start with ``http://`` or ``https://``.
        output_dir: Directory to save image (created if doesn't exist).
            Defaults to ``<system temp dir>/flux_lora_images``.

    Returns:
        Path to downloaded image.

    Raises:
        ValueError: If URL is invalid.
        ConnectionError: If download fails.
        IOError: If saving fails.
    """
    if not url.startswith(("http://", "https://")):
        raise ValueError(f"Invalid URL: {url}")

    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
    os.makedirs(output_dir, exist_ok=True)

    # (substring of Content-Type, extension); first match wins.
    extension_map = (
        ("jpeg", ".jpg"),
        ("jpg", ".jpg"),
        ("png", ".png"),
        ("webp", ".webp"),
        ("gif", ".gif"),
    )

    output_path: Optional[str] = None
    succeeded = False
    try:
        # ``with`` closes the response, releasing the connection that
        # stream=True keeps open until the body is fully consumed.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            content_type = response.headers.get("Content-Type", "").lower()
            ext = next(
                (suffix for marker, suffix in extension_map if marker in content_type),
                ".png",
            )
            output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        succeeded = True
        return output_path
    except requests.RequestException as e:
        # Checked first: RequestException is itself a subclass of IOError.
        raise ConnectionError(f"Failed to download image: {str(e)}") from e
    except IOError as e:
        raise IOError(f"Failed to save image: {str(e)}") from e
    finally:
        if not succeeded and output_path is not None and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                # Best-effort cleanup; must not mask the error being raised.
                pass