# OpenDeepResearch / scripts / flux_lora_tool.py
# Leonardo
# Sync local Space with Hub
# eaaf050 verified
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The Footscray Coding Collective. All rights reserved.
"""
Zhou Protocol FLUX-LoRA Integration Tool

This module provides a smolagents Tool implementation for interacting with the
FLUX-LoRA-DLC API. It enables agents to generate high-quality images with
customizable LoRA models.

Usage:
    flux_tool = FluxLoRATool()
    agent = CodeAgent(tools=[flux_tool], ...)
"""
import logging
import os
import tempfile
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional
# Third-party
import requests
from gradio_client import Client
from PIL import Image
from smolagents import Tool
# -----------------------------------------------------------------------------
# CONSTANTS AND TYPE DEFINITIONS
# -----------------------------------------------------------------------------
@dataclass
class LoRAModelInfo:
    """Describe a single LoRA model.

    Only ``name`` is required; the remaining fields default to ``None`` when
    the information is not available.
    """

    # Identifier of the LoRA model.
    name: str
    # Free-form description of the model, if any.
    description: Optional[str] = None
    # URL of an example image for the model, if any.
    example_image_url: Optional[str] = None
@dataclass
class ImageGenerationResult:
    """Hold the outcome of one image generation call.

    ``image_path`` and ``seed`` are required; ``metadata`` is optional and
    defaults to ``None``.
    """

    # Path of the generated image file.
    image_path: str
    # Seed associated with the generated image.
    seed: int
    # Optional extra key/value details about the generation.
    metadata: Optional[Dict[str, Any]] = None
# -----------------------------------------------------------------------------
# CORE TOOL IMPLEMENTATION
# -----------------------------------------------------------------------------
class FluxLoRATool(Tool):
    """
    Tool for generating images using FLUX-LoRA-DLC API.

    This tool implements the Zhou Protocol integration patterns to provide
    a clean, efficient interface for image generation using LoRA models.

    The Gradio client is created lazily on first use. Failures inside
    ``forward`` are reported as human-readable strings rather than raised,
    so the calling agent can read the message and react to it.
    """

    name = "flux_lora_generator"
    description = """
    Generates high-quality images using FLUX-LoRA models.
    Can use custom LoRA models, adjust image parameters, and handle image inputs.
    """
    # smolagents only accepts its own type names ("number", not "float") and
    # requires "nullable": True on every input whose ``forward`` parameter
    # defaults to None. The "optional"/"default" keys are kept as
    # informational extras for existing readers of this schema.
    inputs = {
        "prompt": {
            "type": "string",
            "description": "Detailed description of the desired image.",
        },
        "image_input": {
            "type": "string",
            "description": "Optional URL or file path to input image for img2img generation.",
            "optional": True,
            "nullable": True,
        },
        "image_strength": {
            "type": "number",
            "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
            "optional": True,
            "nullable": True,
            "default": 0.75,
        },
        "cfg_scale": {
            "type": "number",
            "description": "Guidance scale for prompt adherence (1.0-30.0).",
            "optional": True,
            "nullable": True,
            "default": 3.5,
        },
        "steps": {
            "type": "integer",
            "description": "Number of sampling steps (10-100).",
            "optional": True,
            "nullable": True,
            "default": 28,
        },
        "seed": {
            "type": "integer",
            "description": "Random seed for reproducibility. Use -1 for random seed.",
            "optional": True,
            "nullable": True,
            "default": -1,
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels.",
            "optional": True,
            "nullable": True,
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels.",
            "optional": True,
            "nullable": True,
            "default": 1024,
        },
        "lora_scale": {
            "type": "number",
            "description": "LoRA influence scale (0.0-1.0).",
            "optional": True,
            "nullable": True,
            "default": 0.95,
        },
        "custom_lora": {
            "type": "string",
            "description": "Custom LoRA model to use. Leave empty for default.",
            "optional": True,
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(
        self,
        api_url: str = "xkerser/FLUX-LoRA-DLC",
        image_save_dir: Optional[str] = None,
        connection_timeout: int = 60,
        verbose: bool = False,
    ):
        """
        Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.

        Args:
            api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
            image_save_dir: Directory to save generated images (created if doesn't exist).
                Defaults to a ``flux_lora_images`` folder in the system temp directory.
            connection_timeout: API connection timeout in seconds
            verbose: Enable detailed logging
        """
        super().__init__()

        # Initialize logging
        self.logger = logging.getLogger("flux_lora_tool")
        self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)

        # Set up client settings; the client itself is built on first access
        self.api_url = api_url
        self.connection_timeout = connection_timeout
        self._client = None  # Lazy initialization

        # Set up image storage directory
        self.image_save_dir = image_save_dir or os.path.join(
            tempfile.gettempdir(), "flux_lora_images"
        )
        os.makedirs(self.image_save_dir, exist_ok=True)
        self.logger.info(
            "FluxLoRATool initialized. Images will be saved to: %s",
            self.image_save_dir,
        )

    @property
    def client(self) -> Client:
        """
        Get or initialize the Gradio client with proper connection handling.

        Returns:
            Initialized Gradio client

        Raises:
            ConnectionError: If client initialization fails
        """
        if self._client is None:
            try:
                # gradio_client.Client has no ``timeout`` keyword; request
                # timeouts are forwarded to httpx through ``httpx_kwargs``.
                self._client = Client(
                    self.api_url,
                    httpx_kwargs={"timeout": self.connection_timeout},
                )
                self.logger.debug("Gradio client initialized for: %s", self.api_url)
            except Exception as e:
                error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
                self.logger.error(error_msg)
                raise ConnectionError(error_msg) from e
        return self._client

    def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
        """
        Validate and normalize input parameters with Zhou Protocol validation patterns.

        Missing or ``None`` numeric parameters are replaced by their defaults.
        A remote ``image_input`` is downloaded to a local file here.

        Args:
            **kwargs: Input parameters

        Returns:
            Validated and normalized parameters

        Raises:
            ValueError: If input validation fails
            ConnectionError: If a remote input image cannot be downloaded
        """
        validated: Dict[str, Any] = {}

        # Required parameter: prompt
        if not kwargs.get("prompt"):
            raise ValueError("Prompt is required for image generation")
        validated["prompt"] = kwargs["prompt"]

        # Image input handling
        input_image = kwargs.get("image_input")
        if input_image:
            if not isinstance(input_image, str):
                raise ValueError("Image input must be a URL or file path string")
            if input_image.startswith(("http://", "https://")):
                # Remote image: fetch it so the API receives a local file
                validated["image_input"] = self._download_image(input_image)
            else:
                if not os.path.exists(input_image):
                    raise ValueError(f"Image file not found: {input_image}")
                validated["image_input"] = input_image

        # Numeric parameter validation with constraints
        numeric_params = {
            "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
            "cfg_scale": {"min": 1.0, "max": 30.0, "default": 3.5},
            "steps": {"min": 10, "max": 100, "default": 28},
            "width": {"min": 128, "max": 2048, "default": 1024},
            "height": {"min": 128, "max": 2048, "default": 1024},
            "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95},
        }
        integer_params = ("steps", "width", "height")

        for param, constraints in numeric_params.items():
            value = kwargs.get(param)
            if value is None:
                validated[param] = constraints["default"]
                continue

            # Type conversion if needed
            if param in integer_params:
                try:
                    value = int(value)
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Parameter '{param}' must be an integer") from e
            else:
                try:
                    value = float(value)
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Parameter '{param}' must be a number") from e

            # Chained comparison also rejects NaN, which would slip through
            # separate ``<`` / ``>`` checks.
            if not constraints["min"] <= value <= constraints["max"]:
                raise ValueError(
                    f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
                )
            validated[param] = value

        # Seed: explicit value, or a random one for None / -1
        validated["seed"] = self._resolve_seed(kwargs.get("seed"))

        # Custom LoRA handling
        if kwargs.get("custom_lora"):
            validated["custom_lora"] = kwargs["custom_lora"]

        return validated

    def _resolve_seed(self, seed: Any) -> int:
        """
        Turn the user-supplied seed into a concrete integer.

        ``None`` and ``-1`` both request a random seed. The API is asked
        first; if that fails for any reason a locally generated seed is used
        so that seed selection alone never aborts a generation.

        Args:
            seed: Raw seed value (may be None)

        Returns:
            Seed to send to the API

        Raises:
            ValueError: If the seed cannot be converted to an integer
        """
        if seed is not None:
            try:
                seed = int(seed)
            except (ValueError, TypeError) as e:
                raise ValueError("Seed must be an integer") from e
            if seed != -1:
                return seed

        try:
            return self._get_random_seed()
        except Exception as e:
            self.logger.warning("Failed to get random seed from API: %s", e)
            # Fallback to Python's random
            import random

            return random.randint(0, 2**32 - 1)

    def _download_image(self, url: str) -> str:
        """
        Download image from URL and save to local file.

        Args:
            url: Image URL

        Returns:
            Local file path

        Raises:
            ConnectionError: If download or saving fails
        """
        temp_path = None
        try:
            # The context manager releases the streamed connection even when
            # the download is interrupted.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()

                # Generate temporary file path
                file_ext = self._guess_extension(
                    response.headers.get("Content-Type", "")
                )
                temp_path = os.path.join(
                    self.image_save_dir, f"input_{uuid.uuid4().hex}{file_ext}"
                )

                # Save image
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            self.logger.debug("Downloaded image from %s to %s", url, temp_path)
            return temp_path
        except Exception as e:
            # Best-effort removal of a partially written file
            if temp_path is not None and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            error_msg = f"Failed to download image from {url}: {str(e)}"
            self.logger.error(error_msg)
            raise ConnectionError(error_msg) from e

    def _guess_extension(self, content_type: str) -> str:
        """
        Guess file extension from content type.

        Args:
            content_type: HTTP Content-Type header

        Returns:
            File extension (with dot); ``.png`` when the type is not recognized
        """
        content_type = content_type.lower()
        known_types = (
            ("jpeg", ".jpg"),
            ("jpg", ".jpg"),
            ("png", ".png"),
            ("webp", ".webp"),
            ("gif", ".gif"),
        )
        for marker, extension in known_types:
            if marker in content_type:
                return extension
        return ".png"  # Default to PNG

    def _get_random_seed(self) -> int:
        """
        Get a random seed from the API.

        Returns:
            Random seed value

        Raises:
            ValueError: If the API returns a non-numeric result
            Exception: Any error raised while contacting the API is re-raised
        """
        try:
            result = self.client.predict(api_name="/get_random_value")
            if isinstance(result, (int, float)):
                return int(result)
            raise ValueError(f"Unexpected result type: {type(result)}")
        except Exception as e:
            # Just log and re-raise; the caller decides on a fallback
            self.logger.warning("Failed to get random seed: %s", e)
            raise

    def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
        """
        Add or remove custom LoRA model.

        Args:
            custom_lora: Custom LoRA model string; a falsy value removes the
                current custom LoRA instead

        Raises:
            RuntimeError: If LoRA handling fails
        """
        if not custom_lora:
            # Remove any existing custom LoRA
            try:
                self.client.predict(api_name="/remove_custom_lora")
                self.logger.debug("Removed custom LoRA")
            except Exception as e:
                error_msg = f"Failed to remove custom LoRA: {str(e)}"
                self.logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            return

        # Add custom LoRA
        try:
            self.client.predict(custom_lora=custom_lora, api_name="/add_custom_lora")
            self.logger.debug("Added custom LoRA: %s", custom_lora)
        except Exception as e:
            error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
            self.logger.error(error_msg)
            raise RuntimeError(error_msg) from e

    def forward(
        self,
        prompt: str,
        image_input: Optional[str] = None,
        image_strength: Optional[float] = None,
        cfg_scale: Optional[float] = None,
        steps: Optional[int] = None,
        seed: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        lora_scale: Optional[float] = None,
        custom_lora: Optional[str] = None,
    ) -> str:
        """
        Generate an image with FLUX-LoRA.

        Validation, LoRA setup and generation errors are not raised; each is
        returned as a descriptive message string.

        Args:
            prompt: Text description of the desired image
            image_input: Optional path or URL to input image for img2img
            image_strength: Strength of input image influence (0.0-1.0)
            cfg_scale: Guidance scale (1.0-30.0)
            steps: Number of sampling steps (10-100)
            seed: Random seed (-1 or None for random)
            width: Image width in pixels (128-2048)
            height: Image height in pixels (128-2048)
            lora_scale: LoRA influence scale (0.0-1.0)
            custom_lora: Custom LoRA model to use

        Returns:
            Formatted string with image generation results, or an error
            message describing which step failed
        """
        # Step 1: Validate and normalize inputs
        try:
            params = self._validate_inputs(
                prompt=prompt,
                image_input=image_input,
                image_strength=image_strength,
                cfg_scale=cfg_scale,
                steps=steps,
                seed=seed,
                width=width,
                height=height,
                lora_scale=lora_scale,
                custom_lora=custom_lora,
            )
        except ValueError as e:
            return f"Parameter validation failed: {str(e)}"
        except ConnectionError as e:
            # Raised when a remote input image cannot be downloaded
            return f"Input image retrieval failed: {str(e)}"
        self.logger.debug("Validated parameters: %s", params)

        # Step 2: Handle custom LoRA if specified
        # NOTE(review): when no custom LoRA is requested nothing is sent, so
        # the removal branch of _handle_custom_lora is not reached from here.
        # Confirm whether a previously added LoRA should be cleared.
        if "custom_lora" in params:
            try:
                self._handle_custom_lora(params.pop("custom_lora"))
            except RuntimeError as e:
                return f"Custom LoRA setup failed: {str(e)}"

        # Step 3: Generate image
        try:
            generation_args = {
                "prompt": params["prompt"],
                "image_strength": params["image_strength"],
                "cfg_scale": params["cfg_scale"],
                "steps": params["steps"],
                "randomize_seed": False,  # We handle seed explicitly
                "seed": params["seed"],
                "width": params["width"],
                "height": params["height"],
                "lora_scale": params["lora_scale"],
            }

            # Add image input if available
            if params.get("image_input"):
                from gradio_client import handle_file

                generation_args["image_input"] = handle_file(
                    params.pop("image_input")
                )

            self.logger.info("Generating image with params: %s", generation_args)
            result = self.client.predict(api_name="/run_lora", **generation_args)

            # Expect (image_path, seed, ...); accept a list as well as a tuple
            if not (isinstance(result, (tuple, list)) and len(result) >= 2):
                raise ValueError(f"Unexpected API response format: {result}")
            image_path, actual_seed = result[0], result[1]

            # Save image to our directory
            try:
                output_path = self._save_image(image_path)
                image_result = ImageGenerationResult(
                    image_path=output_path, seed=int(actual_seed)
                )
                return self._format_result(image_result, params["prompt"])
            except Exception as e:
                self.logger.error("Failed to save generated image: %s", e)
                return f"Image generated but failed to save: {str(e)}"
        except Exception as e:
            error_msg = f"Image generation failed: {str(e)}"
            self.logger.error(error_msg)
            return error_msg

    def _save_image(self, image_path: str) -> str:
        """
        Save generated image to specified directory.

        Args:
            image_path: Path to generated image from API

        Returns:
            Path to saved image

        Raises:
            IOError: If image saving fails
        """
        try:
            # Short random suffix keeps filenames distinct between calls
            unique_suffix = uuid.uuid4().hex[:8]
            output_filename = f"flux_lora_{unique_suffix}.png"
            output_path = os.path.join(self.image_save_dir, output_filename)

            # Context manager closes the source file handle after saving
            with Image.open(image_path) as img:
                img.save(output_path)

            self.logger.debug("Saved image to %s", output_path)
            return output_path
        except Exception as e:
            error_msg = f"Failed to save image: {str(e)}"
            self.logger.error(error_msg)
            raise IOError(error_msg) from e

    def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
        """
        Format the image generation result as a string.

        Args:
            result: Image generation result
            prompt: Original prompt

        Returns:
            Formatted string with generation details, one item per line
        """
        lines = [
            "📷 Image generated successfully!",
            f"🖼️ Image saved to: {result.image_path}",
            f"🌱 Seed used: {result.seed}",
            f"📝 Original prompt: {prompt}",
        ]

        # Add metadata if available
        if result.metadata:
            lines.append("📊 Additional metadata:")
            lines.extend(
                f" - {key}: {value}" for key, value in result.metadata.items()
            )

        return "\n".join(lines)
# -----------------------------------------------------------------------------
# UTILITY FUNCTIONS
# -----------------------------------------------------------------------------
def download_image(url: str, output_dir: Optional[str] = None) -> str:
    """
    Standalone utility to download an image from a URL.

    The file extension is chosen from the response's Content-Type header
    (jpeg/jpg, png, webp, gif), falling back to ``.png``.

    Args:
        url: Image URL; must start with ``http://`` or ``https://``
        output_dir: Directory to save image (created if doesn't exist).
            Defaults to a ``flux_lora_images`` folder in the system temp directory.

    Returns:
        Path to downloaded image

    Raises:
        ValueError: If URL is invalid
        ConnectionError: If download fails
        IOError: If saving fails
    """
    if not isinstance(url, str) or not url.startswith(("http://", "https://")):
        raise ValueError(f"Invalid URL: {url}")

    # Setup output directory
    if output_dir is None:
        output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
    os.makedirs(output_dir, exist_ok=True)

    known_types = (
        ("jpeg", ".jpg"),
        ("jpg", ".jpg"),
        ("png", ".png"),
        ("webp", ".webp"),
        ("gif", ".gif"),
    )
    output_path = None
    completed = False
    try:
        # The context manager releases the streamed connection on every path
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            # Determine file extension
            content_type = response.headers.get("Content-Type", "").lower()
            ext = next(
                (suffix for marker, suffix in known_types if marker in content_type),
                ".png",
            )

            # Save image
            output_path = os.path.join(
                output_dir, f"download_{uuid.uuid4().hex}{ext}"
            )
            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        completed = True
        return output_path
    # RequestException derives from IOError, so it must be handled first
    except requests.RequestException as e:
        raise ConnectionError(f"Failed to download image: {str(e)}") from e
    except IOError as e:
        raise IOError(f"Failed to save image: {str(e)}") from e
    finally:
        # Best-effort removal of a partially written file
        if not completed and output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass