# Trouter-Imagine-1 / utils.py
# Author: Luke-Bergen — commit 56fcf2e (verified), "Create utils.py"
#!/usr/bin/env python3
"""
Trouter-Imagine-1 Utilities and Helper Functions
Apache 2.0 License
Comprehensive utility module providing:
- Prompt enhancement and optimization
- Image post-processing
- Metadata management
- Performance monitoring
- Configuration management
- Quality assessment
- Batch processing helpers
- File management
- API wrappers
- Advanced preprocessing
"""
# Standard library
import hashlib
import json
import logging
import os
import random
import re
import time
from collections import defaultdict
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union

# Third-party
import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont
# Configure logging.
# NOTE(review): calling logging.basicConfig at import time mutates the global
# logging configuration of any application that imports this module — confirm
# this is intended; libraries usually configure logging only in their main().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# DATA CLASSES FOR CONFIGURATION
# ============================================================================
@dataclass
class GenerationConfig:
    """Configuration for image generation.

    Mirrors the common Stable-Diffusion-style generation parameters;
    call ``validate()`` before handing the config to a pipeline.
    """
    prompt: str
    negative_prompt: str = ""
    width: int = 512
    height: int = 512
    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    seed: Optional[int] = None
    num_images: int = 1

    def to_dict(self) -> Dict:
        """Return the configuration as a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> 'GenerationConfig':
        """Build a configuration from a dict produced by ``to_dict()``."""
        return cls(**data)

    def validate(self) -> bool:
        """Validate configuration parameters.

        Returns:
            True when the configuration is usable.

        Raises:
            ValueError: if any parameter is out of range.
        """
        # Zero/negative dimensions would pass the % 8 check, so test them first.
        if self.width <= 0 or self.height <= 0:
            raise ValueError("Width and height must be positive")
        # Diffusion UNets downsample by factors of 8, so dimensions must align.
        if self.width % 8 != 0 or self.height % 8 != 0:
            raise ValueError("Width and height must be multiples of 8")
        if self.num_inference_steps < 1:
            raise ValueError("num_inference_steps must be at least 1")
        # guidance_scale == 0 simply disables guidance; only negatives are invalid.
        if self.guidance_scale < 0:
            raise ValueError("guidance_scale must be non-negative")
        if self.num_images < 1:
            raise ValueError("num_images must be at least 1")
        return True
@dataclass
class GenerationMetadata:
    """Record of the parameters and timing behind one generated image."""
    prompt: str
    negative_prompt: str
    model_id: str
    width: int
    height: int
    num_inference_steps: int
    guidance_scale: float
    seed: int
    timestamp: str
    generation_time: float
    scheduler: str = "unknown"

    def to_json(self) -> str:
        """Serialize this record to a pretty-printed JSON string."""
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> 'GenerationMetadata':
        """Reconstruct a metadata record from a JSON string."""
        payload = json.loads(json_str)
        return cls(**payload)
# ============================================================================
# PROMPT ENHANCEMENT
# ============================================================================
class PromptEnhancer:
    """Enhance and optimize prompts for better generation"""

    # Generic quality terms appended when add_quality is requested.
    QUALITY_BOOSTERS = [
        "highly detailed",
        "professional",
        "4k",
        "ultra detailed",
        "sharp focus",
        "intricate details"
    ]
    # Style name -> keywords that push generation toward that style.
    STYLE_KEYWORDS = {
        "photo": ["photography", "realistic", "photorealistic", "sharp focus"],
        "art": ["digital art", "concept art", "artistic", "detailed"],
        "paint": ["oil painting", "painterly", "brushstrokes", "canvas"],
        "anime": ["anime style", "manga", "cel shaded", "vibrant"],
        "3d": ["3d render", "octane render", "unreal engine", "cgi"]
    }
    # Artifacts that are almost always worth excluding.
    NEGATIVE_DEFAULTS = [
        "blurry", "low quality", "distorted", "deformed",
        "ugly", "bad anatomy", "watermark", "signature"
    ]

    @staticmethod
    def enhance_prompt(
        prompt: str,
        style: Optional[str] = None,
        add_quality: bool = True,
        add_details: bool = True
    ) -> str:
        """
        Enhance a prompt with quality boosters and style keywords.

        Args:
            prompt: Base prompt
            style: Style to apply (photo, art, paint, anime, 3d)
            add_quality: Append the first three quality boosters
            add_details: Currently unused; kept for interface compatibility

        Returns:
            Enhanced prompt
        """
        enhanced = prompt.strip()
        # Append at most two style keywords to keep the prompt focused.
        if style and style.lower() in PromptEnhancer.STYLE_KEYWORDS:
            style_words = PromptEnhancer.STYLE_KEYWORDS[style.lower()]
            enhanced += ", " + ", ".join(style_words[:2])
        if add_quality:
            quality_words = PromptEnhancer.QUALITY_BOOSTERS[:3]
            enhanced += ", " + ", ".join(quality_words)
        return enhanced

    @staticmethod
    def build_negative_prompt(
        base_negative: str = "",
        include_defaults: bool = True,
        subject_type: Optional[str] = None
    ) -> str:
        """
        Build a comprehensive negative prompt.

        Args:
            base_negative: User-provided negative prompt
            include_defaults: Include default negative terms
            subject_type: Type of subject (person, animal, face, landscape)

        Returns:
            Enhanced negative prompt
        """
        negatives = []
        if base_negative:
            negatives.append(base_negative)
        if include_defaults:
            negatives.extend(PromptEnhancer.NEGATIVE_DEFAULTS)
        # Subject-specific artifacts to exclude.
        subject_negatives = {
            "person": ["extra limbs", "extra fingers", "fused fingers", "bad hands"],
            "animal": ["extra legs", "incorrect anatomy", "fused limbs"],
            "face": ["asymmetric eyes", "crossed eyes", "bad teeth"],
            "landscape": ["oversaturated", "underexposed", "poor composition"]
        }
        if subject_type and subject_type.lower() in subject_negatives:
            negatives.extend(subject_negatives[subject_type.lower()])
        return ", ".join(negatives)

    @staticmethod
    def extract_keywords(prompt: str) -> List[str]:
        """
        Extract important keywords from a prompt.

        Punctuation is stripped *before* stop-word filtering so tokens like
        "the," are recognized as stop words (previously the comparison ran
        on the raw token, letting them through), and tokens that consist
        only of punctuation no longer yield empty strings.
        """
        stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'with', 'by', 'for'}
        cleaned = (w.strip('.,!?;:') for w in prompt.lower().split())
        return [w for w in cleaned if w and w not in stop_words]

    @staticmethod
    def validate_prompt(prompt: str) -> Tuple[bool, List[str]]:
        """
        Validate a prompt and return warnings.

        Returns:
            (is_valid, list_of_warnings)
        """
        warnings = []
        if len(prompt.strip()) < 3:
            warnings.append("Prompt is very short, consider adding more detail")
        if len(prompt) > 500:
            warnings.append("Prompt is very long, may be truncated")
        # Contradictory quality terms confuse the text encoder.
        if "high quality" in prompt.lower() and "low quality" in prompt.lower():
            warnings.append("Contradictory quality terms detected")
        # Excessive comma-separated fragments dilute each term's weight.
        if prompt.count(',') > 20:
            warnings.append("Too many commas, consider simplifying")
        return len(warnings) == 0, warnings
# ============================================================================
# IMAGE POST-PROCESSING
# ============================================================================
class ImageProcessor:
    """Post-processing utilities for generated images"""

    @staticmethod
    def enhance_image(
        image: Image.Image,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        sharpness: float = 1.0
    ) -> Image.Image:
        """
        Enhance image with various adjustments.

        Args:
            image: Input PIL Image
            brightness: Brightness factor (1.0 = no change)
            contrast: Contrast factor
            saturation: Color saturation factor
            sharpness: Sharpness factor

        Returns:
            Enhanced image (the original object when every factor is 1.0)
        """
        enhanced = image
        # Apply each enhancer only when its factor would actually change pixels.
        if brightness != 1.0:
            enhanced = ImageEnhance.Brightness(enhanced).enhance(brightness)
        if contrast != 1.0:
            enhanced = ImageEnhance.Contrast(enhanced).enhance(contrast)
        if saturation != 1.0:
            enhanced = ImageEnhance.Color(enhanced).enhance(saturation)
        if sharpness != 1.0:
            enhanced = ImageEnhance.Sharpness(enhanced).enhance(sharpness)
        return enhanced

    @staticmethod
    def apply_filter(
        image: Image.Image,
        filter_type: str = "none"
    ) -> Image.Image:
        """
        Apply various filters to image.

        Args:
            image: Input image
            filter_type: One of blur, sharpen, edge_enhance, edge_enhance_more,
                smooth, smooth_more, detail. Unknown values return the image
                unchanged.

        Returns:
            Filtered image
        """
        filters = {
            "blur": ImageFilter.BLUR,
            "sharpen": ImageFilter.SHARPEN,
            "edge_enhance": ImageFilter.EDGE_ENHANCE,
            "edge_enhance_more": ImageFilter.EDGE_ENHANCE_MORE,
            "smooth": ImageFilter.SMOOTH,
            "smooth_more": ImageFilter.SMOOTH_MORE,
            "detail": ImageFilter.DETAIL
        }
        if filter_type.lower() in filters:
            return image.filter(filters[filter_type.lower()])
        return image

    @staticmethod
    def upscale_simple(
        image: Image.Image,
        scale: int = 2,
        method: str = "lanczos"
    ) -> Image.Image:
        """Simple upscaling using PIL resampling (unknown method -> lanczos)."""
        methods = {
            "lanczos": Image.LANCZOS,
            "bicubic": Image.BICUBIC,
            "bilinear": Image.BILINEAR,
            "nearest": Image.NEAREST
        }
        resample = methods.get(method.lower(), Image.LANCZOS)
        new_size = (image.width * scale, image.height * scale)
        return image.resize(new_size, resample=resample)

    @staticmethod
    def add_watermark(
        image: Image.Image,
        text: str,
        position: str = "bottom-right",
        opacity: int = 128
    ) -> Image.Image:
        """
        Add a text watermark to a copy of the image.

        Args:
            image: Source image (not modified)
            text: Watermark text
            position: top-left, top-right, bottom-left, bottom-right, center
            opacity: Alpha value 0-255 for the white watermark text

        Returns:
            Watermarked copy of the image
        """
        watermark = image.copy()
        draw = ImageDraw.Draw(watermark, 'RGBA')
        # Prefer a real TrueType font; fall back to PIL's built-in bitmap font.
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; truetype raises OSError (missing file) or ImportError
        # (FreeType support not compiled in).
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
        except (OSError, ImportError):
            font = ImageFont.load_default()
        # Measure the rendered text so it can be anchored precisely.
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        positions = {
            "top-left": (10, 10),
            "top-right": (image.width - text_width - 10, 10),
            "bottom-left": (10, image.height - text_height - 10),
            "bottom-right": (image.width - text_width - 10, image.height - text_height - 10),
            "center": ((image.width - text_width) // 2, (image.height - text_height) // 2)
        }
        pos = positions.get(position, positions["bottom-right"])
        # White text with caller-controlled alpha.
        draw.text(pos, text, fill=(255, 255, 255, opacity), font=font)
        return watermark

    @staticmethod
    def create_comparison(
        images: List[Image.Image],
        labels: Optional[List[str]] = None,
        padding: int = 10
    ) -> Image.Image:
        """
        Create a side-by-side comparison strip of images.

        Images are scaled (preserving aspect ratio) to the tallest input's
        height, then pasted left to right with `padding` pixels between them.

        Raises:
            ValueError: if `images` is empty.
        """
        if not images:
            raise ValueError("No images provided")
        # Normalize all images to a common height.
        max_height = max(img.height for img in images)
        resized_images = []
        for img in images:
            if img.height != max_height:
                ratio = max_height / img.height
                new_width = int(img.width * ratio)
                img = img.resize((new_width, max_height), Image.LANCZOS)
            resized_images.append(img)
        total_width = sum(img.width for img in resized_images) + padding * (len(resized_images) - 1)
        comparison = Image.new('RGB', (total_width, max_height), color='white')
        x_offset = 0
        for i, img in enumerate(resized_images):
            comparison.paste(img, (x_offset, 0))
            # Overlay the label in the image's top-left corner, if provided.
            if labels and i < len(labels):
                draw = ImageDraw.Draw(comparison)
                draw.text((x_offset + 10, 10), labels[i], fill='white')
            x_offset += img.width + padding
        return comparison

    @staticmethod
    def get_image_stats(image: Image.Image) -> Dict:
        """
        Get statistical information about an image.

        Numeric values are converted to plain Python numbers (via float()/
        .item()) instead of numpy scalars so the result can be passed to
        json.dumps directly.
        """
        img_array = np.array(image)
        stats = {
            "size": image.size,
            "mode": image.mode,
            "mean_brightness": float(np.mean(img_array)),
            "std_brightness": float(np.std(img_array)),
            # .item() preserves int vs float depending on the image mode.
            "min_value": np.min(img_array).item(),
            "max_value": np.max(img_array).item()
        }
        if img_array.ndim == 3:
            stats["mean_per_channel"] = np.mean(img_array, axis=(0, 1)).tolist()
        return stats
# ============================================================================
# METADATA MANAGEMENT
# ============================================================================
class MetadataManager:
    """Manage image metadata"""

    @staticmethod
    def embed_metadata(
        image: Image.Image,
        metadata: Union[Dict, GenerationMetadata]
    ) -> Tuple[Image.Image, 'PngImagePlugin.PngInfo']:
        """
        Prepare metadata for embedding into a PNG image.

        PIL cannot attach text chunks to an in-memory image, so this returns
        the image together with a PngInfo object; pass the latter to
        ``image.save(path, pnginfo=png_info)`` to actually embed it.
        (Return annotation fixed: this has always returned a tuple, not a
        bare Image.)

        Args:
            image: Image the metadata belongs to (returned unchanged)
            metadata: Mapping or GenerationMetadata; every value is stored
                as a text chunk keyed by its field name

        Returns:
            (image, png_info) tuple ready for saving.
        """
        from PIL import PngImagePlugin
        png_info = PngImagePlugin.PngInfo()
        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)
        for key, value in metadata.items():
            png_info.add_text(key, str(value))
        return image, png_info

    @staticmethod
    def extract_metadata(image_path: str) -> Dict:
        """
        Extract text metadata from a saved image.

        Returns an empty dict when the image carries no text chunks.
        """
        # Context manager closes the underlying file handle (the original
        # left it open until garbage collection).
        with Image.open(image_path) as image:
            if hasattr(image, 'text'):
                return dict(image.text)
        return {}

    @staticmethod
    def save_metadata_json(
        metadata: Union[Dict, GenerationMetadata],
        filepath: str
    ):
        """Save metadata to a separate JSON sidecar file."""
        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)
        with open(filepath, 'w') as f:
            json.dump(metadata, f, indent=2)

    @staticmethod
    def load_metadata_json(filepath: str) -> Dict:
        """Load metadata previously saved with save_metadata_json."""
        with open(filepath, 'r') as f:
            return json.load(f)
# ============================================================================
# PERFORMANCE MONITORING
# ============================================================================
class PerformanceMonitor:
    """Track wall-clock generation times and report GPU memory usage."""

    def __init__(self):
        # Elapsed seconds for every completed start()/stop() pair.
        self.generation_times = []
        # Reserved for memory snapshots; kept for interface compatibility.
        self.memory_usage = []
        # Timestamp of the pending start() call, or None when idle.
        self.start_time = None

    def start(self):
        """Begin timing a generation."""
        self.start_time = time.time()

    def stop(self) -> float:
        """Finish timing; record and return elapsed seconds (0.0 if never started)."""
        if self.start_time is None:
            return 0.0
        elapsed = time.time() - self.start_time
        self.start_time = None
        self.generation_times.append(elapsed)
        return elapsed

    def get_gpu_memory(self) -> Dict:
        """Return current CUDA memory usage in GiB, or {"available": False} without a GPU."""
        if not torch.cuda.is_available():
            return {"available": False}
        gib = 1024 ** 3
        return {
            "allocated": torch.cuda.memory_allocated() / gib,
            "reserved": torch.cuda.memory_reserved() / gib,
            "max_allocated": torch.cuda.max_memory_allocated() / gib
        }

    def get_statistics(self) -> Dict:
        """Summarize recorded times; {"no_data": True} when nothing was recorded."""
        times = self.generation_times
        if not times:
            return {"no_data": True}
        return {
            "total_generations": len(times),
            "total_time": sum(times),
            "average_time": np.mean(times),
            "min_time": min(times),
            "max_time": max(times),
            "std_time": np.std(times)
        }

    def reset(self):
        """Discard all recorded timings and clear any pending timer."""
        self.generation_times = []
        self.memory_usage = []
        self.start_time = None
# ============================================================================
# CONFIGURATION MANAGEMENT
# ============================================================================
class ConfigManager:
    """Manage configuration files"""

    @staticmethod
    def load_config(filepath: str) -> Dict:
        """Read a configuration dictionary from a JSON file."""
        with open(filepath, 'r') as f:
            return json.load(f)

    @staticmethod
    def save_config(config: Dict, filepath: str):
        """Write a configuration dictionary to a JSON file (pretty-printed)."""
        with open(filepath, 'w') as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def create_default_config() -> Dict:
        """Return the stock configuration used when no file is supplied."""
        defaults = {
            "width": 512,
            "height": 512,
            "num_inference_steps": 30,
            "guidance_scale": 7.5
        }
        optimization = {
            "attention_slicing": True,
            "vae_slicing": True,
            "xformers": True
        }
        output = {
            "format": "png",
            "quality": 95,
            "save_metadata": True
        }
        return {
            "model_id": "OpenTrouter/Trouter-Imagine-1",
            "device": "cuda",
            "dtype": "float16",
            "defaults": defaults,
            "optimization": optimization,
            "output": output
        }

    @staticmethod
    def validate_config(config: Dict) -> Tuple[bool, List[str]]:
        """Check a configuration for required keys and a known device name.

        Returns:
            (is_valid, list_of_error_messages)
        """
        errors = []
        for key in ("model_id", "device", "defaults"):
            if key not in config:
                errors.append(f"Missing required key: {key}")
        device = config.get("device")
        if "device" in config and device not in ["cuda", "cpu", "mps"]:
            errors.append(f"Invalid device: {config['device']}")
        return not errors, errors
# ============================================================================
# BATCH PROCESSING HELPERS
# ============================================================================
class BatchProcessor:
    """Helper for batch processing operations"""

    @staticmethod
    def load_prompts_from_file(filepath: str) -> List[str]:
        """
        Load prompts from a text file, one per line.

        Blank lines and comment lines starting with '#' are skipped.
        Lines are stripped *before* the comment check so indented comments
        are also recognized (previously "  # comment" was kept as a prompt).
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line and not line.startswith('#')]

    @staticmethod
    def load_prompts_from_json(filepath: str) -> List[Dict]:
        """
        Load prompts and configs from a JSON file.

        Accepts either a top-level list or an object with a "prompts" key.

        Raises:
            ValueError: if the JSON has neither shape.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "prompts" in data:
            return data["prompts"]
        raise ValueError("Invalid JSON format")

    @staticmethod
    def save_batch_results(
        results: List[Tuple[Image.Image, Dict]],
        output_dir: str,
        prefix: str = "batch"
    ):
        """Save each (image, metadata) pair as a PNG plus a JSON sidecar file."""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        for i, (image, metadata) in enumerate(results):
            # Zero-padded index keeps files sorted in generation order.
            image.save(output_path / f"{prefix}_{i:04d}.png")
            metadata_file = output_path / f"{prefix}_{i:04d}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

    @staticmethod
    def create_batch_report(
        results: List[Tuple[Image.Image, Dict]],
        output_file: str
    ):
        """
        Create a JSON summary report of batch processing.

        Timing statistics use plain Python arithmetic so the report stays
        json-serializable (np.mean returns np.float64, which json.dump
        rejects with a TypeError).
        """
        report = {
            "total_images": len(results),
            "timestamp": datetime.now().isoformat(),
            "images": []
        }
        for i, (_, metadata) in enumerate(results):
            report["images"].append({
                "index": i,
                "prompt": metadata.get("prompt", ""),
                "generation_time": metadata.get("generation_time", 0),
                "parameters": {
                    "width": metadata.get("width", 0),
                    "height": metadata.get("height", 0),
                    "steps": metadata.get("num_inference_steps", 0),
                    "guidance": metadata.get("guidance_scale", 0)
                }
            })
        times = [m.get("generation_time", 0) for _, m in results]
        if times:
            total_time = sum(times)
            report["statistics"] = {
                "total_time": total_time,
                "average_time": total_time / len(times),
                "min_time": min(times),
                "max_time": max(times)
            }
        with open(output_file, 'w') as f:
            json.dump(report, f, indent=2)
# ============================================================================
# FILE MANAGEMENT
# ============================================================================
class FileManager:
    """Utilities for file management"""

    @staticmethod
    def create_directory_structure(base_dir: str) -> Dict[str, Path]:
        """Create the standard output directory tree and return name -> Path."""
        base = Path(base_dir)
        layout = ("outputs", "metadata", "configs", "logs", "temp")
        dirs = {name: base / name for name in layout}
        for path in dirs.values():
            path.mkdir(parents=True, exist_ok=True)
        return dirs

    @staticmethod
    def generate_filename(
        prompt: str,
        timestamp: bool = True,
        max_length: int = 50
    ) -> str:
        """Build a filesystem-safe .png filename from a prompt.

        Args:
            prompt: Prompt text to slugify
            timestamp: Prepend a YYYYmmdd_HHMMSS stamp when True
            max_length: Maximum length of the slug portion
        """
        # Drop anything that is not word/space/hyphen, then collapse to underscores.
        slug = re.sub(r'[^\w\s-]', '', prompt.lower())
        slug = re.sub(r'[-\s]+', '_', slug)[:max_length]
        if not timestamp:
            return f"{slug}.png"
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"{stamp}_{slug}.png"

    @staticmethod
    def get_file_hash(filepath: str) -> str:
        """Return the hex MD5 digest of a file, read in 4 KiB chunks."""
        digest = hashlib.md5()
        with open(filepath, "rb") as fh:
            while True:
                chunk = fh.read(4096)
                if not chunk:
                    break
                digest.update(chunk)
        return digest.hexdigest()

    @staticmethod
    def cleanup_temp_files(temp_dir: str, older_than_hours: int = 24):
        """Delete files in temp_dir last modified more than older_than_hours ago."""
        root = Path(temp_dir)
        if not root.exists():
            return
        cutoff_time = time.time() - older_than_hours * 3600
        for file in root.glob("*"):
            if file.is_file() and file.stat().st_mtime < cutoff_time:
                file.unlink()
                logger.info(f"Deleted old temp file: {file}")
# ============================================================================
# QUALITY ASSESSMENT
# ============================================================================
class QualityAssessor:
    """Assess image quality"""

    @staticmethod
    def _grayscale(image: Image.Image) -> np.ndarray:
        """Convert a PIL image to a grayscale (luminance) numpy array."""
        return np.array(image.convert('L'))

    @staticmethod
    def calculate_sharpness(image: Image.Image) -> float:
        """Sharpness as the variance of the Laplacian response (higher = sharper)."""
        from scipy import signal
        gray = QualityAssessor._grayscale(image)
        laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
        response = signal.convolve2d(gray, laplacian, mode='valid')
        return float(np.var(response))

    @staticmethod
    def calculate_brightness(image: Image.Image) -> float:
        """Average grayscale intensity (0-255)."""
        return float(np.mean(QualityAssessor._grayscale(image)))

    @staticmethod
    def calculate_contrast(image: Image.Image) -> float:
        """Contrast as the standard deviation of grayscale intensity."""
        return float(np.std(QualityAssessor._grayscale(image)))

    @staticmethod
    def assess_quality(image: Image.Image) -> Dict:
        """Run every metric and return them in a single dictionary."""
        return {
            "sharpness": QualityAssessor.calculate_sharpness(image),
            "brightness": QualityAssessor.calculate_brightness(image),
            "contrast": QualityAssessor.calculate_contrast(image),
            "resolution": f"{image.width}x{image.height}",
            "aspect_ratio": image.width / image.height
        }
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def seed_everything(seed: int):
    """Seed every RNG used for generation so results are reproducible.

    Seeds Python's ``random`` module (fixes a NameError: the module was
    used here without ever being imported), NumPy, and PyTorch (CPU and
    all CUDA devices), and switches cuDNN into deterministic mode.

    Args:
        seed: The seed value applied to every generator.

    Note:
        Disabling cuDNN benchmarking can slow down convolutions; that is
        the usual price of determinism.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def get_optimal_resolution(
    target_pixels: int,
    aspect_ratio: str = "1:1"
) -> Tuple[int, int]:
    """
    Calculate optimal resolution for a target pixel count.

    Args:
        target_pixels: Target total pixels (e.g., 512*512 = 262144)
        aspect_ratio: Desired aspect ratio as "W:H" (e.g., "16:9", "4:3").
            Generalized from a fixed table: any positive integer pair is
            accepted (e.g., "21:9", "16:10"); an unparseable value falls
            back to "1:1", matching the previous behavior for unknown keys.

    Returns:
        (width, height) tuple, both rounded down to multiples of 8
        (required by diffusion UNets).
    """
    ratio_w, ratio_h = 1, 1
    match = re.fullmatch(r'\s*(\d+)\s*:\s*(\d+)\s*', aspect_ratio)
    if match:
        w, h = int(match.group(1)), int(match.group(2))
        if w > 0 and h > 0:
            ratio_w, ratio_h = w, h
    # Solve width*height = target with width/height = ratio_w/ratio_h.
    height = int(np.sqrt(target_pixels * ratio_h / ratio_w))
    width = int(height * ratio_w / ratio_h)
    # Round down to the nearest multiple of 8.
    width = (width // 8) * 8
    height = (height // 8) * 8
    return width, height
def estimate_generation_time(
    width: int,
    height: int,
    steps: int,
    device: str = "cuda",
    gpu_model: str = "RTX 3080"
) -> float:
    """
    Estimate generation time based on parameters.

    The estimate scales a per-step baseline (measured at 512x512 on a
    handful of GPUs) by the step count and the relative pixel count.

    Returns:
        Estimated time in seconds
    """
    # Seconds per denoising step at 512x512, by GPU model.
    per_step_at_512 = {
        "RTX 4090": 0.04,
        "RTX 3090": 0.07,
        "RTX 3080": 0.10,
        "RTX 2080": 0.15,
        "M1 Max": 0.25
    }
    # Unknown hardware falls back to the RTX 3080 baseline.
    step_time = per_step_at_512.get(gpu_model, 0.10)
    # Time grows linearly with pixel count relative to the 512x512 baseline.
    resolution_scale = (width * height) / (512 * 512)
    return step_time * steps * resolution_scale
# Public API of this module; `from utils import *` exposes exactly these names.
__all__ = [
    'GenerationConfig',
    'GenerationMetadata',
    'PromptEnhancer',
    'ImageProcessor',
    'MetadataManager',
    'PerformanceMonitor',
    'ConfigManager',
    'BatchProcessor',
    'FileManager',
    'QualityAssessor',
    'seed_everything',
    'get_optimal_resolution',
    'estimate_generation_time'
]