#!/usr/bin/env python3
"""
Trouter-Imagine-1 Utilities and Helper Functions
Apache 2.0 License

Comprehensive utility module providing:
- Prompt enhancement and optimization
- Image post-processing
- Metadata management
- Performance monitoring
- Configuration management
- Quality assessment
- Batch processing helpers
- File management
- API wrappers
- Advanced preprocessing
"""

import torch
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont, PngImagePlugin
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
import json
import os
import hashlib
import random
from pathlib import Path
from datetime import datetime
import re
import logging
from dataclasses import dataclass, asdict
import time
from collections import defaultdict

# Configure logging (note: this configures the root logger at import time)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ============================================================================
# DATA CLASSES FOR CONFIGURATION
# ============================================================================

@dataclass
class GenerationConfig:
    """Configuration for a single image-generation request."""

    prompt: str
    negative_prompt: str = ""
    width: int = 512
    height: int = 512
    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    seed: Optional[int] = None
    num_images: int = 1

    def to_dict(self) -> Dict:
        """Return the configuration as a plain dict keyed by field name."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> 'GenerationConfig':
        """Build a config from a dict whose keys match the field names.

        Raises:
            TypeError: if ``data`` has unknown keys or lacks ``prompt``.
        """
        return cls(**data)

    def validate(self) -> bool:
        """Validate configuration parameters.

        Returns:
            True when every check passes.

        Raises:
            ValueError: describing the first failing check.
        """
        # Checked first: negative multiples of 8 (e.g. -8) would otherwise
        # pass the modulo test below.
        if self.width <= 0 or self.height <= 0:
            raise ValueError("Width and height must be positive")
        if self.width % 8 != 0 or self.height % 8 != 0:
            raise ValueError("Width and height must be multiples of 8")
        if self.num_inference_steps < 1:
            raise ValueError("num_inference_steps must be at least 1")
        if self.guidance_scale < 0:
            # 0 is accepted, so the actual requirement is non-negativity.
            raise ValueError("guidance_scale must be non-negative")
        if self.num_images < 1:
            raise ValueError("num_images must be at least 1")
        return True


@dataclass
class GenerationMetadata:
    """Metadata recorded for a generated image."""

    prompt: str
    negative_prompt: str
    model_id: str
    width: int
    height: int
    num_inference_steps: int
    guidance_scale: float
    seed: int
    timestamp: str
    generation_time: float
    scheduler: str = "unknown"

    def to_json(self) -> str:
        """Serialize all fields to an indented JSON string."""
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> 'GenerationMetadata':
        """Rebuild metadata from a JSON string produced by ``to_json``."""
        return cls(**json.loads(json_str))


# ============================================================================
# PROMPT ENHANCEMENT
# ============================================================================

class PromptEnhancer:
    """Enhance and optimize prompts for better generation"""

    QUALITY_BOOSTERS = [
        "highly detailed", "professional", "4k", "ultra detailed",
        "sharp focus", "intricate details"
    ]

    STYLE_KEYWORDS = {
        "photo": ["photography", "realistic", "photorealistic", "sharp focus"],
        "art": ["digital art", "concept art", "artistic", "detailed"],
        "paint": ["oil painting", "painterly", "brushstrokes", "canvas"],
        "anime": ["anime style", "manga", "cel shaded", "vibrant"],
        "3d": ["3d render", "octane render", "unreal engine", "cgi"]
    }

    NEGATIVE_DEFAULTS = [
        "blurry", "low quality", "distorted", "deformed", "ugly",
        "bad anatomy", "watermark", "signature"
    ]

    @staticmethod
    def enhance_prompt(
        prompt: str,
        style: Optional[str] = None,
        add_quality: bool = True,
        add_details: bool = True
    ) -> str:
        """
        Enhance a prompt with quality boosters and style keywords.

        Args:
            prompt: Base prompt (surrounding whitespace is stripped)
            style: Style to apply (photo, art, paint, anime, 3d); the first two
                keywords of that style are appended. Unknown styles are ignored.
            add_quality: Append the first three QUALITY_BOOSTERS
            add_details: Currently has no effect.
                NOTE(review): documented as "add detail-related keywords" but
                never implemented; kept for interface compatibility.

        Returns:
            Comma-separated enhanced prompt. An empty prompt no longer
            produces a leading ", ".
        """
        parts = [prompt.strip()]

        style_key = style.lower() if style else None
        if style_key in PromptEnhancer.STYLE_KEYWORDS:
            parts.extend(PromptEnhancer.STYLE_KEYWORDS[style_key][:2])

        if add_quality:
            parts.extend(PromptEnhancer.QUALITY_BOOSTERS[:3])

        # Skip empty pieces so an empty base prompt does not yield ", ...".
        return ", ".join(part for part in parts if part)

    @staticmethod
    def build_negative_prompt(
        base_negative: str = "",
        include_defaults: bool = True,
        subject_type: Optional[str] = None
    ) -> str:
        """
        Build a comprehensive negative prompt.

        Args:
            base_negative: User-provided negative prompt (placed first)
            include_defaults: Include NEGATIVE_DEFAULTS
            subject_type: Type of subject (person, animal, face, landscape);
                unknown values add nothing

        Returns:
            Comma-separated negative prompt
        """
        negatives = []

        if base_negative:
            negatives.append(base_negative)

        if include_defaults:
            negatives.extend(PromptEnhancer.NEGATIVE_DEFAULTS)

        # Subject-specific negatives
        subject_negatives = {
            "person": ["extra limbs", "extra fingers", "fused fingers", "bad hands"],
            "animal": ["extra legs", "incorrect anatomy", "fused limbs"],
            "face": ["asymmetric eyes", "crossed eyes", "bad teeth"],
            "landscape": ["oversaturated", "underexposed", "poor composition"]
        }

        if subject_type and subject_type.lower() in subject_negatives:
            negatives.extend(subject_negatives[subject_type.lower()])

        return ", ".join(negatives)

    @staticmethod
    def extract_keywords(prompt: str) -> List[str]:
        """Extract keywords from a prompt.

        The prompt is lower-cased and split on whitespace. Surrounding
        punctuation is stripped from each word. Stop words and empty tokens
        are dropped.
        """
        stop_words = frozenset({'a', 'an', 'the', 'in', 'on', 'at', 'with', 'by', 'for'})
        # Strip punctuation *before* the stop-word test so "the," is removed too.
        cleaned = (word.strip('.,!?;:') for word in prompt.lower().split())
        return [word for word in cleaned if word and word not in stop_words]

    @staticmethod
    def validate_prompt(prompt: str) -> Tuple[bool, List[str]]:
        """
        Validate a prompt and return warnings.

        Returns:
            (is_valid, list_of_warnings); is_valid is True only when no
            warnings were produced.
        """
        warnings = []

        if len(prompt.strip()) < 3:
            warnings.append("Prompt is very short, consider adding more detail")

        if len(prompt) > 500:
            warnings.append("Prompt is very long, may be truncated")

        # Check for common issues
        if "high quality" in prompt.lower() and "low quality" in prompt.lower():
            warnings.append("Contradictory quality terms detected")

        # Check for excessive punctuation
        if prompt.count(',') > 20:
            warnings.append("Too many commas, consider simplifying")

        return len(warnings) == 0, warnings


# ============================================================================
# IMAGE POST-PROCESSING
# ============================================================================

class ImageProcessor:
    """Post-processing utilities for generated images"""

    @staticmethod
    def enhance_image(
        image: Image.Image,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        sharpness: float = 1.0
    ) -> Image.Image:
        """
        Enhance image with various adjustments.

        Args:
            image: Input PIL Image
            brightness: Brightness factor (1.0 = no change)
            contrast: Contrast factor
            saturation: Color saturation factor
            sharpness: Sharpness factor

        Returns:
            Enhanced image. Adjustments are applied in the order
            brightness, contrast, saturation, sharpness. Factors equal to
            1.0 are skipped.
        """
        enhanced = image

        if brightness != 1.0:
            enhancer = ImageEnhance.Brightness(enhanced)
            enhanced = enhancer.enhance(brightness)

        if contrast != 1.0:
            enhancer = ImageEnhance.Contrast(enhanced)
            enhanced = enhancer.enhance(contrast)

        if saturation != 1.0:
            enhancer = ImageEnhance.Color(enhanced)
            enhanced = enhancer.enhance(saturation)

        if sharpness != 1.0:
            enhancer = ImageEnhance.Sharpness(enhanced)
            enhanced = enhancer.enhance(sharpness)

        return enhanced

    @staticmethod
    def apply_filter(
        image: Image.Image,
        filter_type: str = "none"
    ) -> Image.Image:
        """
        Apply one of Pillow's built-in filters.

        Args:
            image: Input image
            filter_type: blur, sharpen, edge_enhance, edge_enhance_more,
                smooth, smooth_more or detail (case-insensitive)

        Returns:
            Filtered image, or the input unchanged for unknown filter names
        """
        filters = {
            "blur": ImageFilter.BLUR,
            "sharpen": ImageFilter.SHARPEN,
            "edge_enhance": ImageFilter.EDGE_ENHANCE,
            "edge_enhance_more": ImageFilter.EDGE_ENHANCE_MORE,
            "smooth": ImageFilter.SMOOTH,
            "smooth_more": ImageFilter.SMOOTH_MORE,
            "detail": ImageFilter.DETAIL
        }

        if filter_type.lower() in filters:
            return image.filter(filters[filter_type.lower()])
        return image

    @staticmethod
    def upscale_simple(
        image: Image.Image,
        scale: int = 2,
        method: str = "lanczos"
    ) -> Image.Image:
        """Upscale by an integer factor with PIL resampling.

        Unknown ``method`` names fall back to Lanczos.
        """
        methods = {
            "lanczos": Image.LANCZOS,
            "bicubic": Image.BICUBIC,
            "bilinear": Image.BILINEAR,
            "nearest": Image.NEAREST
        }

        resample = methods.get(method.lower(), Image.LANCZOS)
        new_size = (image.width * scale, image.height * scale)
        return image.resize(new_size, resample=resample)

    @staticmethod
    def add_watermark(
        image: Image.Image,
        text: str,
        position: str = "bottom-right",
        opacity: int = 128
    ) -> Image.Image:
        """Return a copy of ``image`` with semi-transparent white text drawn on it.

        Args:
            image: Source image (not modified)
            text: Watermark text
            position: top-left, top-right, bottom-left, bottom-right or
                center; unknown values use bottom-right
            opacity: Alpha value of the text fill (0-255)
        """
        watermark = image.copy()
        draw = ImageDraw.Draw(watermark, 'RGBA')

        # Prefer DejaVu Sans; fall back to Pillow's built-in bitmap font when the
        # file is missing (OSError) or FreeType support is unavailable (ImportError).
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        # Calculate position
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        positions = {
            "top-left": (10, 10),
            "top-right": (image.width - text_width - 10, 10),
            "bottom-left": (10, image.height - text_height - 10),
            "bottom-right": (image.width - text_width - 10, image.height - text_height - 10),
            "center": ((image.width - text_width) // 2, (image.height - text_height) // 2)
        }

        pos = positions.get(position, positions["bottom-right"])

        # Draw with opacity
        draw.text(pos, text, fill=(255, 255, 255, opacity), font=font)

        return watermark

    @staticmethod
    def create_comparison(
        images: List[Image.Image],
        labels: Optional[List[str]] = None,
        padding: int = 10
    ) -> Image.Image:
        """Create a side-by-side comparison of images.

        Images shorter than the tallest one are resized (preserving aspect
        ratio) to the same height. They are then pasted left to right on a
        white RGB canvas, ``padding`` pixels apart. ``labels[i]``, if present,
        is drawn in white at the top-left of image ``i``.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if not images:
            raise ValueError("No images provided")

        # Ensure all images have same height
        max_height = max(img.height for img in images)
        resized_images = []

        for img in images:
            if img.height != max_height:
                ratio = max_height / img.height
                new_width = int(img.width * ratio)
                img = img.resize((new_width, max_height), Image.LANCZOS)
            resized_images.append(img)

        # Calculate total width
        total_width = sum(img.width for img in resized_images) + padding * (len(resized_images) - 1)

        # Create comparison image
        comparison = Image.new('RGB', (total_width, max_height), color='white')

        x_offset = 0
        for i, img in enumerate(resized_images):
            comparison.paste(img, (x_offset, 0))

            # Add label if provided
            if labels and i < len(labels):
                draw = ImageDraw.Draw(comparison)
                draw.text((x_offset + 10, 10), labels[i], fill='white')

            x_offset += img.width + padding

        return comparison

    @staticmethod
    def get_image_stats(image: Image.Image) -> Dict:
        """Get statistical information about an image.

        Numeric values are converted to built-in Python types so the result
        can be passed straight to ``json.dump``.
        """
        img_array = np.array(image)

        stats = {
            "size": image.size,
            "mode": image.mode,
            "mean_brightness": float(np.mean(img_array)),
            "std_brightness": float(np.std(img_array)),
            # .item() yields int for integer arrays and float for float arrays
            "min_value": np.min(img_array).item(),
            "max_value": np.max(img_array).item()
        }

        if len(img_array.shape) == 3:
            stats["mean_per_channel"] = np.mean(img_array, axis=(0, 1)).tolist()

        return stats


# ============================================================================
# METADATA MANAGEMENT
# ============================================================================

class MetadataManager:
    """Manage image metadata"""

    @staticmethod
    def embed_metadata(
        image: Image.Image,
        metadata: Union[Dict, GenerationMetadata]
    ) -> Tuple[Image.Image, PngImagePlugin.PngInfo]:
        """Prepare PNG text chunks for ``metadata``.

        The image itself is not modified. Pass the returned ``PngInfo`` to
        ``image.save(path, pnginfo=png_info)`` to write the metadata.

        Returns:
            (image, png_info) — every value is stored as ``str(value)``.
        """
        png_info = PngImagePlugin.PngInfo()

        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)

        for key, value in metadata.items():
            png_info.add_text(key, str(value))

        return image, png_info

    @staticmethod
    def extract_metadata(image_path: str) -> Dict:
        """Read text metadata (e.g. PNG text chunks) from a saved image.

        Returns an empty dict when the format exposes no ``text`` mapping.
        The file handle is closed before returning.
        """
        with Image.open(image_path) as image:
            text = getattr(image, 'text', None)
            return dict(text) if text else {}

    @staticmethod
    def save_metadata_json(
        metadata: Union[Dict, GenerationMetadata],
        filepath: str
    ):
        """Save metadata to a separate UTF-8 JSON file."""
        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

    @staticmethod
    def load_metadata_json(filepath: str) -> Dict:
        """Load metadata from a UTF-8 JSON file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)


# ============================================================================
# PERFORMANCE MONITORING
# ============================================================================

class PerformanceMonitor:
    """Monitor and log generation performance"""

    def __init__(self):
        self.generation_times = []
        self.memory_usage = []  # NOTE(review): never populated in this module
        self.start_time = None

    def start(self):
        """Start timing (monotonic clock, so wall-clock changes don't skew it)."""
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing, record and return the elapsed seconds.

        Returns 0.0 without recording anything if ``start`` was not called.
        """
        if self.start_time is None:
            return 0.0

        elapsed = time.perf_counter() - self.start_time
        self.generation_times.append(elapsed)
        self.start_time = None
        return elapsed

    def get_gpu_memory(self) -> Dict:
        """Get current CUDA memory usage in GiB, or {"available": False}."""
        if not torch.cuda.is_available():
            return {"available": False}

        return {
            "allocated": torch.cuda.memory_allocated() / 1024**3,  # GB
            "reserved": torch.cuda.memory_reserved() / 1024**3,
            "max_allocated": torch.cuda.max_memory_allocated() / 1024**3
        }

    def get_statistics(self) -> Dict:
        """Summarize recorded timings, or return {"no_data": True} if none."""
        if not self.generation_times:
            return {"no_data": True}

        return {
            "total_generations": len(self.generation_times),
            "total_time": sum(self.generation_times),
            "average_time": float(np.mean(self.generation_times)),
            "min_time": min(self.generation_times),
            "max_time": max(self.generation_times),
            "std_time": float(np.std(self.generation_times))
        }

    def reset(self):
        """Reset all statistics"""
        self.generation_times = []
        self.memory_usage = []
        self.start_time = None


# ============================================================================
# CONFIGURATION MANAGEMENT
# ============================================================================

class ConfigManager:
    """Manage configuration files"""

    @staticmethod
    def load_config(filepath: str) -> Dict:
        """Load configuration from a UTF-8 JSON file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def save_config(config: Dict, filepath: str):
        """Save configuration to a UTF-8 JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def create_default_config() -> Dict:
        """Create default configuration"""
        return {
            "model_id": "OpenTrouter/Trouter-Imagine-1",
            "device": "cuda",
            "dtype": "float16",
            "defaults": {
                "width": 512,
                "height": 512,
                "num_inference_steps": 30,
                "guidance_scale": 7.5
            },
            "optimization": {
                "attention_slicing": True,
                "vae_slicing": True,
                "xformers": True
            },
            "output": {
                "format": "png",
                "quality": 95,
                "save_metadata": True
            }
        }

    @staticmethod
    def validate_config(config: Dict) -> Tuple[bool, List[str]]:
        """Validate configuration.

        Returns:
            (is_valid, errors) — checks required top-level keys and that
            ``device`` is one of cuda/cpu/mps.
        """
        errors = []

        required_keys = ["model_id", "device", "defaults"]
        for key in required_keys:
            if key not in config:
                errors.append(f"Missing required key: {key}")

        if "device" in config:
            valid_devices = ["cuda", "cpu", "mps"]
            if config["device"] not in valid_devices:
                errors.append(f"Invalid device: {config['device']}")

        return len(errors) == 0, errors


# ============================================================================
# BATCH PROCESSING HELPERS
# ============================================================================

class BatchProcessor:
    """Helper for batch processing operations"""

    @staticmethod
    def load_prompts_from_file(filepath: str) -> List[str]:
        """Load prompts from a text file, one per line.

        Blank lines and lines whose first non-blank character is '#' are
        skipped. Returned prompts are stripped.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            # Test the *stripped* line so indented "# ..." comments are skipped too.
            return [line for line in stripped if line and not line.startswith('#')]

    @staticmethod
    def load_prompts_from_json(filepath: str) -> List[Dict]:
        """Load prompts and configs from a JSON file.

        Accepts either a top-level list or an object with a "prompts" key.

        Raises:
            ValueError: for any other JSON shape.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            return data
        elif isinstance(data, dict) and "prompts" in data:
            return data["prompts"]
        else:
            raise ValueError("Invalid JSON format")

    @staticmethod
    def save_batch_results(
        results: List[Tuple[Image.Image, Dict]],
        output_dir: str,
        prefix: str = "batch"
    ):
        """Save each (image, metadata) pair to ``<prefix>_NNNN.png`` + ``_metadata.json``."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for i, (image, metadata) in enumerate(results):
            stem = f"{prefix}_{i:04d}"

            # Save image
            image.save(output_path / f"{stem}.png")

            # Save metadata
            metadata_file = output_path / f"{stem}_metadata.json"
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

    @staticmethod
    def create_batch_report(
        results: List[Tuple[Image.Image, Dict]],
        output_file: str
    ):
        """Write a JSON summary of a batch, with per-image parameters and timing stats.

        Missing metadata keys default to "" (prompt) or 0 (numeric fields).
        The "statistics" section is omitted for an empty batch.
        """
        images = []
        times = []

        for i, (_, metadata) in enumerate(results):
            generation_time = metadata.get("generation_time", 0)
            times.append(generation_time)
            images.append({
                "index": i,
                "prompt": metadata.get("prompt", ""),
                "generation_time": generation_time,
                "parameters": {
                    "width": metadata.get("width", 0),
                    "height": metadata.get("height", 0),
                    "steps": metadata.get("num_inference_steps", 0),
                    "guidance": metadata.get("guidance_scale", 0)
                }
            })

        report = {
            "total_images": len(results),
            "timestamp": datetime.now().isoformat(),
            "images": images
        }

        # Calculate statistics
        if times:
            report["statistics"] = {
                "total_time": sum(times),
                "average_time": float(np.mean(times)),
                "min_time": min(times),
                "max_time": max(times)
            }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)


# ============================================================================
# FILE MANAGEMENT
# ============================================================================

class FileManager:
    """Utilities for file management"""

    @staticmethod
    def create_directory_structure(base_dir: str) -> Dict[str, Path]:
        """Create outputs/metadata/configs/logs/temp under ``base_dir`` and return their paths."""
        base = Path(base_dir)

        dirs = {
            "outputs": base / "outputs",
            "metadata": base / "metadata",
            "configs": base / "configs",
            "logs": base / "logs",
            "temp": base / "temp"
        }

        for dir_path in dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

        return dirs

    @staticmethod
    def generate_filename(
        prompt: str,
        timestamp: bool = True,
        max_length: int = 50
    ) -> str:
        """Generate a ``.png`` filename from a prompt.

        The prompt is lower-cased. Characters other than word characters,
        whitespace and '-' are dropped. Runs of whitespace/'-' become '_'.
        The result is truncated to ``max_length`` and edge underscores are
        trimmed. Falls back to "image" when nothing usable remains.
        Optionally prefixed with a local ``YYYYmmdd_HHMMSS`` timestamp.
        """
        # Clean prompt
        clean = re.sub(r'[^\w\s-]', '', prompt.lower())
        clean = re.sub(r'[-\s]+', '_', clean)
        clean = clean[:max_length].strip('_') or "image"

        if timestamp:
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            return f"{ts}_{clean}.png"

        return f"{clean}.png"

    @staticmethod
    def get_file_hash(filepath: str) -> str:
        """Return the hex MD5 digest of a file, read in chunks (integrity, not security)."""
        hash_md5 = hashlib.md5()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(65536), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    @staticmethod
    def cleanup_temp_files(temp_dir: str, older_than_hours: int = 24):
        """Delete regular files directly in ``temp_dir`` whose mtime is older than the cutoff.

        Does nothing if the directory does not exist. Subdirectories are not
        descended into.
        """
        temp_path = Path(temp_dir)
        if not temp_path.exists():
            return

        cutoff_time = time.time() - (older_than_hours * 3600)

        for file in temp_path.glob("*"):
            if file.is_file() and file.stat().st_mtime < cutoff_time:
                file.unlink()
                logger.info("Deleted old temp file: %s", file)


# ============================================================================
# QUALITY ASSESSMENT
# ============================================================================

class QualityAssessor:
    """Assess image quality"""

    @staticmethod
    def calculate_sharpness(image: Image.Image) -> float:
        """Calculate sharpness as the variance of the Laplacian of the grayscale image.

        Returns 0.0 for images smaller than the 3x3 kernel, where a 'valid'
        convolution has no meaningful output.
        """
        img_array = np.asarray(image.convert('L'), dtype=np.float64)
        if img_array.shape[0] < 3 or img_array.shape[1] < 3:
            return 0.0

        laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])

        # Convolve
        from scipy import signal
        filtered = signal.convolve2d(img_array, laplacian, mode='valid')
        return float(np.var(filtered))

    @staticmethod
    def calculate_brightness(image: Image.Image) -> float:
        """Calculate average grayscale brightness (0-255)."""
        img_array = np.array(image.convert('L'))
        return float(np.mean(img_array))

    @staticmethod
    def calculate_contrast(image: Image.Image) -> float:
        """Calculate contrast as the standard deviation of grayscale values."""
        img_array = np.array(image.convert('L'))
        return float(np.std(img_array))

    @staticmethod
    def assess_quality(image: Image.Image) -> Dict:
        """Comprehensive quality assessment"""
        return {
            "sharpness": QualityAssessor.calculate_sharpness(image),
            "brightness": QualityAssessor.calculate_brightness(image),
            "contrast": QualityAssessor.calculate_contrast(image),
            "resolution": f"{image.width}x{image.height}",
            "aspect_ratio": image.width / image.height
        }


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def seed_everything(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) for reproducibility.

    Also makes cuDNN deterministic and disables its benchmark autotuner.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _parse_aspect_ratio(aspect_ratio: str) -> Tuple[float, float]:
    """Parse a "W:H" string into two positive finite numbers, falling back to (1, 1)."""
    try:
        w_text, h_text = aspect_ratio.split(":")
        ratio_w, ratio_h = float(w_text), float(h_text)
    except (AttributeError, ValueError):
        return (1, 1)
    if not (np.isfinite(ratio_w) and np.isfinite(ratio_h)) or ratio_w <= 0 or ratio_h <= 0:
        return (1, 1)
    return (ratio_w, ratio_h)


def get_optimal_resolution(
    target_pixels: int,
    aspect_ratio: str = "1:1"
) -> Tuple[int, int]:
    """
    Calculate a resolution close to a target pixel count.

    Args:
        target_pixels: Target total pixels (e.g., 512*512 = 262144)
        aspect_ratio: Desired aspect ratio as "W:H" (e.g., "16:9", "4:3", "1:1").
            Common ratios are looked up directly. Any other "W:H" string of
            positive numbers (e.g. "21:9") is parsed. Unparseable values fall
            back to 1:1.

    Returns:
        (width, height) tuple, each rounded *down* to a multiple of 8 and
        never smaller than 8
    """
    ratios = {
        "1:1": (1, 1),
        "4:3": (4, 3),
        "3:4": (3, 4),
        "16:9": (16, 9),
        "9:16": (9, 16),
        "3:2": (3, 2),
        "2:3": (2, 3)
    }

    ratio = ratios.get(aspect_ratio)
    if ratio is None:
        ratio = _parse_aspect_ratio(aspect_ratio)
    ratio_w, ratio_h = ratio

    # Calculate dimensions
    height = int(np.sqrt(target_pixels * ratio_h / ratio_w))
    width = int(height * ratio_w / ratio_h)

    # Round down to a multiple of 8 (the constraint GenerationConfig.validate
    # enforces) so the pixel count does not exceed the target; clamp at 8.
    width = max(8, (width // 8) * 8)
    height = max(8, (height // 8) * 8)

    return width, height


def estimate_generation_time(
    width: int,
    height: int,
    steps: int,
    device: str = "cuda",
    gpu_model: str = "RTX 3080"
) -> float:
    """
    Roughly estimate generation time from a per-step table scaled by pixel count.

    Args:
        width, height: Output size; time scales linearly with width*height
            relative to 512x512
        steps: Number of inference steps
        device: Currently unused (NOTE(review): kept for interface compatibility)
        gpu_model: Key into the built-in per-step timing table; unknown
            models use the RTX 3080 figure (0.10 s/step)

    Returns:
        Estimated time in seconds
    """
    # Base time per step (seconds) for different GPUs at 512x512
    base_times = {
        "RTX 4090": 0.04,
        "RTX 3090": 0.07,
        "RTX 3080": 0.10,
        "RTX 2080": 0.15,
        "M1 Max": 0.25
    }

    base_time = base_times.get(gpu_model, 0.10)

    # Scale by resolution
    pixel_factor = (width * height) / (512 * 512)

    # Estimate
    estimated = base_time * steps * pixel_factor

    return estimated


# Export main classes and functions
__all__ = [
    'GenerationConfig',
    'GenerationMetadata',
    'PromptEnhancer',
    'ImageProcessor',
    'MetadataManager',
    'PerformanceMonitor',
    'ConfigManager',
    'BatchProcessor',
    'FileManager',
    'QualityAssessor',
    'seed_everything',
    'get_optimal_resolution',
    'estimate_generation_time'
]