|
|
|
|
|
""" |
|
|
Trouter-Imagine-1 Utilities and Helper Functions |
|
|
Apache 2.0 License |
|
|
|
|
|
Comprehensive utility module providing: |
|
|
- Prompt enhancement and optimization |
|
|
- Image post-processing |
|
|
- Metadata management |
|
|
- Performance monitoring |
|
|
- Configuration management |
|
|
- Quality assessment |
|
|
- Batch processing helpers |
|
|
- File management |
|
|
- API wrappers |
|
|
- Advanced preprocessing |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont |
|
|
import numpy as np |
|
|
from typing import List, Dict, Tuple, Optional, Union |
|
|
import json |
|
|
import os |
|
|
import hashlib |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import re |
|
|
import logging |
|
|
from dataclasses import dataclass, asdict |
|
|
import time |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
# Import-time side effect: configure the root logger at INFO level, then create
# the module-level logger that the helpers in this file log through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GenerationConfig:
    """Parameters describing one image-generation request.

    Only ``prompt`` is required; every other field has a default.
    """

    prompt: str
    negative_prompt: str = ""
    width: int = 512
    height: int = 512
    num_inference_steps: int = 30
    guidance_scale: float = 7.5
    seed: Optional[int] = None
    num_images: int = 1

    def to_dict(self) -> Dict:
        """Return the configuration as a plain dictionary keyed by field name."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict) -> 'GenerationConfig':
        """Build a configuration from a dictionary of field values.

        Raises:
            TypeError: If ``data`` contains a key that is not a field, or
                omits the required ``prompt``.
        """
        return cls(**data)

    def validate(self) -> bool:
        """Validate configuration parameters.

        Returns:
            True when every check passes.

        Raises:
            ValueError: If width/height are not positive multiples of 8, if
                ``num_inference_steps`` or ``num_images`` is below 1, or if
                ``guidance_scale`` is negative.
        """
        # 0 and negative multiples of 8 satisfy the modulo test below, so the
        # sign has to be checked separately.
        if self.width <= 0 or self.height <= 0:
            raise ValueError("Width and height must be positive")
        if self.width % 8 != 0 or self.height % 8 != 0:
            raise ValueError("Width and height must be multiples of 8")
        if self.num_inference_steps < 1:
            raise ValueError("num_inference_steps must be at least 1")
        if self.guidance_scale < 0:
            # A scale of exactly 0 is accepted, hence "non-negative".
            raise ValueError("guidance_scale must be non-negative")
        if self.num_images < 1:
            raise ValueError("num_images must be at least 1")
        return True
|
|
|
|
|
|
|
|
@dataclass
class GenerationMetadata:
    """Record of the settings and timing associated with a generated image.

    All fields except ``scheduler`` are required.
    """

    prompt: str
    negative_prompt: str
    model_id: str
    width: int
    height: int
    num_inference_steps: int
    guidance_scale: float
    seed: int
    timestamp: str
    generation_time: float
    scheduler: str = "unknown"

    def to_json(self) -> str:
        """Serialize all fields to a JSON object string (two-space indent)."""
        payload = asdict(self)
        return json.dumps(payload, indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> 'GenerationMetadata':
        """Recreate an instance from a JSON object keyed by field name."""
        field_values = json.loads(json_str)
        return cls(**field_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PromptEnhancer:
    """Enhance and optimize prompts for better generation."""

    QUALITY_BOOSTERS = [
        "highly detailed",
        "professional",
        "4k",
        "ultra detailed",
        "sharp focus",
        "intricate details"
    ]

    STYLE_KEYWORDS = {
        "photo": ["photography", "realistic", "photorealistic", "sharp focus"],
        "art": ["digital art", "concept art", "artistic", "detailed"],
        "paint": ["oil painting", "painterly", "brushstrokes", "canvas"],
        "anime": ["anime style", "manga", "cel shaded", "vibrant"],
        "3d": ["3d render", "octane render", "unreal engine", "cgi"]
    }

    NEGATIVE_DEFAULTS = [
        "blurry", "low quality", "distorted", "deformed",
        "ugly", "bad anatomy", "watermark", "signature"
    ]

    # Extra negative terms keyed by subject type (see build_negative_prompt).
    SUBJECT_NEGATIVES = {
        "person": ["extra limbs", "extra fingers", "fused fingers", "bad hands"],
        "animal": ["extra legs", "incorrect anatomy", "fused limbs"],
        "face": ["asymmetric eyes", "crossed eyes", "bad teeth"],
        "landscape": ["oversaturated", "underexposed", "poor composition"]
    }

    # Words dropped by extract_keywords, and the punctuation it trims.
    _STOP_WORDS = frozenset(
        {'a', 'an', 'the', 'in', 'on', 'at', 'with', 'by', 'for'}
    )
    _PUNCTUATION = '.,!?;:'

    @staticmethod
    def enhance_prompt(
        prompt: str,
        style: Optional[str] = None,
        add_quality: bool = True,
        add_details: bool = True
    ) -> str:
        """
        Enhance a prompt with quality boosters and style keywords

        Args:
            prompt: Base prompt (surrounding whitespace is removed)
            style: Style to apply (photo, art, paint, anime, 3d); matched
                case-insensitively, unknown styles are ignored
            add_quality: Append the first three QUALITY_BOOSTERS
            add_details: Accepted for API compatibility; currently has no
                effect on the result

        Returns:
            Enhanced prompt: the non-empty pieces joined with ", ". An empty
            base prompt no longer produces a leading ", ".
        """
        parts = []

        base = prompt.strip()
        if base:
            parts.append(base)

        # Only the first two keywords of the chosen style are appended.
        if style and style.lower() in PromptEnhancer.STYLE_KEYWORDS:
            parts.extend(PromptEnhancer.STYLE_KEYWORDS[style.lower()][:2])

        if add_quality:
            parts.extend(PromptEnhancer.QUALITY_BOOSTERS[:3])

        return ", ".join(parts)

    @staticmethod
    def build_negative_prompt(
        base_negative: str = "",
        include_defaults: bool = True,
        subject_type: Optional[str] = None
    ) -> str:
        """
        Build a comprehensive negative prompt

        Args:
            base_negative: User-provided negative prompt; surrounding
                whitespace is removed and a blank value is skipped
            include_defaults: Include NEGATIVE_DEFAULTS
            subject_type: Type of subject (person, animal, face, landscape);
                matched case-insensitively, unknown types are ignored

        Returns:
            Negative terms joined with ", " (empty string if there are none)
        """
        negatives = []

        base = base_negative.strip() if base_negative else ""
        if base:
            negatives.append(base)

        if include_defaults:
            negatives.extend(PromptEnhancer.NEGATIVE_DEFAULTS)

        if subject_type:
            negatives.extend(
                PromptEnhancer.SUBJECT_NEGATIVES.get(subject_type.lower(), [])
            )

        return ", ".join(negatives)

    @staticmethod
    def extract_keywords(prompt: str) -> List[str]:
        """Extract important keywords from a prompt.

        The prompt is lower-cased and split on whitespace. Leading/trailing
        punctuation (``.,!?;:``) is trimmed from each word *before* the
        stop-word test, so "the," is filtered like "the" and tokens made only
        of punctuation are dropped instead of becoming empty strings.

        Returns:
            Keywords in their original order (duplicates are kept).
        """
        trimmed = (
            word.strip(PromptEnhancer._PUNCTUATION)
            for word in prompt.lower().split()
        )
        return [
            word for word in trimmed
            if word and word not in PromptEnhancer._STOP_WORDS
        ]

    @staticmethod
    def validate_prompt(prompt: str) -> Tuple[bool, List[str]]:
        """
        Validate a prompt and return warnings

        Checks performed: fewer than 3 non-blank characters, more than 500
        characters, both "high quality" and "low quality" present
        (case-insensitive), and more than 20 commas.

        Returns:
            (is_valid, list_of_warnings); is_valid is True only when no
            warning was produced
        """
        warnings = []

        if len(prompt.strip()) < 3:
            warnings.append("Prompt is very short, consider adding more detail")

        if len(prompt) > 500:
            warnings.append("Prompt is very long, may be truncated")

        lowered = prompt.lower()
        if "high quality" in lowered and "low quality" in lowered:
            warnings.append("Contradictory quality terms detected")

        if prompt.count(',') > 20:
            warnings.append("Too many commas, consider simplifying")

        return len(warnings) == 0, warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageProcessor:
    """Post-processing utilities for generated images."""

    @staticmethod
    def enhance_image(
        image: Image.Image,
        brightness: float = 1.0,
        contrast: float = 1.0,
        saturation: float = 1.0,
        sharpness: float = 1.0
    ) -> Image.Image:
        """
        Enhance image with various adjustments

        Adjustments are applied in the order brightness, contrast,
        saturation, sharpness; a factor of exactly 1.0 skips that step.

        Args:
            image: Input PIL Image
            brightness: Brightness factor (1.0 = no change)
            contrast: Contrast factor
            saturation: Color saturation factor
            sharpness: Sharpness factor

        Returns:
            Enhanced image (the input object itself if every factor is 1.0)
        """
        adjustments = (
            (ImageEnhance.Brightness, brightness),
            (ImageEnhance.Contrast, contrast),
            (ImageEnhance.Color, saturation),
            (ImageEnhance.Sharpness, sharpness),
        )

        enhanced = image
        for enhancer_cls, factor in adjustments:
            if factor != 1.0:
                enhanced = enhancer_cls(enhanced).enhance(factor)

        return enhanced

    @staticmethod
    def apply_filter(
        image: Image.Image,
        filter_type: str = "none"
    ) -> Image.Image:
        """
        Apply various filters to image

        Args:
            image: Input image
            filter_type: Type of filter (blur, sharpen, edge_enhance,
                edge_enhance_more, smooth, smooth_more, detail); matched
                case-insensitively

        Returns:
            Filtered image, or the input image unchanged when the filter
            name is not recognised (including the default "none")
        """
        filters = {
            "blur": ImageFilter.BLUR,
            "sharpen": ImageFilter.SHARPEN,
            "edge_enhance": ImageFilter.EDGE_ENHANCE,
            "edge_enhance_more": ImageFilter.EDGE_ENHANCE_MORE,
            "smooth": ImageFilter.SMOOTH,
            "smooth_more": ImageFilter.SMOOTH_MORE,
            "detail": ImageFilter.DETAIL
        }

        chosen = filters.get(filter_type.lower())
        if chosen is None:
            return image
        return image.filter(chosen)

    @staticmethod
    def upscale_simple(
        image: Image.Image,
        scale: int = 2,
        method: str = "lanczos"
    ) -> Image.Image:
        """Simple upscaling using PIL.

        Args:
            image: Input image
            scale: Multiplier applied to both width and height
            method: Resampling method (lanczos, bicubic, bilinear, nearest);
                unknown names fall back to lanczos

        Returns:
            Resized copy of the image
        """
        methods = {
            "lanczos": Image.LANCZOS,
            "bicubic": Image.BICUBIC,
            "bilinear": Image.BILINEAR,
            "nearest": Image.NEAREST
        }

        resample = methods.get(method.lower(), Image.LANCZOS)
        new_size = (image.width * scale, image.height * scale)
        return image.resize(new_size, resample=resample)

    @staticmethod
    def add_watermark(
        image: Image.Image,
        text: str,
        position: str = "bottom-right",
        opacity: int = 128
    ) -> Image.Image:
        """Add text watermark to image.

        Args:
            image: Input image (not modified; a copy is drawn on)
            text: Watermark text
            position: top-left, top-right, bottom-left, bottom-right or
                center; any other value falls back to bottom-right
            opacity: Alpha component of the white fill colour

        Returns:
            Watermarked copy of the image
        """
        watermark = image.copy()
        draw = ImageDraw.Draw(watermark, 'RGBA')

        # Best-effort font loading: fall back to PIL's built-in font when the
        # DejaVu file is missing/unreadable (OSError) or FreeType support is
        # not available (ImportError). Other errors are no longer swallowed.
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # All anchored positions keep a 10px margin from the image edge.
        positions = {
            "top-left": (10, 10),
            "top-right": (image.width - text_width - 10, 10),
            "bottom-left": (10, image.height - text_height - 10),
            "bottom-right": (image.width - text_width - 10, image.height - text_height - 10),
            "center": ((image.width - text_width) // 2, (image.height - text_height) // 2)
        }

        pos = positions.get(position, positions["bottom-right"])

        draw.text(pos, text, fill=(255, 255, 255, opacity), font=font)

        return watermark

    @staticmethod
    def create_comparison(
        images: List[Image.Image],
        labels: Optional[List[str]] = None,
        padding: int = 10
    ) -> Image.Image:
        """Create side-by-side comparison of images.

        Images shorter than the tallest one are rescaled (LANCZOS) to the
        common height with their aspect ratio preserved, then pasted left to
        right on a white RGB canvas separated by ``padding`` pixels.

        Args:
            images: Images to place side by side
            labels: Optional captions; label ``i`` is drawn in white near the
                top-left corner of image ``i``. Extra images get no label.
            padding: Gap in pixels between neighbouring images

        Returns:
            The composed comparison image

        Raises:
            ValueError: If ``images`` is empty
        """
        if not images:
            raise ValueError("No images provided")

        max_height = max(img.height for img in images)
        resized_images = []

        for img in images:
            if img.height != max_height:
                ratio = max_height / img.height
                new_width = int(img.width * ratio)
                img = img.resize((new_width, max_height), Image.LANCZOS)
            resized_images.append(img)

        total_width = sum(img.width for img in resized_images) + padding * (len(resized_images) - 1)

        comparison = Image.new('RGB', (total_width, max_height), color='white')
        # One drawing context is enough for every label.
        draw = ImageDraw.Draw(comparison) if labels else None

        x_offset = 0
        for i, img in enumerate(resized_images):
            comparison.paste(img, (x_offset, 0))

            if labels and i < len(labels):
                draw.text((x_offset + 10, 10), labels[i], fill='white')

            x_offset += img.width + padding

        return comparison

    @staticmethod
    def get_image_stats(image: Image.Image) -> Dict:
        """Get statistical information about an image.

        All numeric values are converted to native Python numbers so the
        result can be passed to ``json.dump`` (NumPy integer scalars such as
        ``np.uint8`` are not JSON-serializable).

        Returns:
            Dict with ``size``, ``mode``, ``mean_brightness``,
            ``std_brightness``, ``min_value``, ``max_value`` and, when the
            pixel array has three dimensions, ``mean_per_channel``.
        """
        img_array = np.array(image)

        stats = {
            "size": image.size,
            "mode": image.mode,
            "mean_brightness": float(np.mean(img_array)),
            "std_brightness": float(np.std(img_array)),
            "min_value": np.min(img_array).item(),
            "max_value": np.max(img_array).item()
        }

        if len(img_array.shape) == 3:
            stats["mean_per_channel"] = np.mean(img_array, axis=(0, 1)).tolist()

        return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetadataManager:
    """Manage image metadata."""

    @staticmethod
    def embed_metadata(
        image: Image.Image,
        metadata: Union[Dict, GenerationMetadata]
    ) -> Tuple[Image.Image, object]:
        """Build PNG text metadata for an image.

        The image itself is returned unchanged; the metadata lives in the
        returned ``PngInfo`` object, which has to be supplied when the image
        is saved (``image.save(path, pnginfo=png_info)``) to end up in the
        PNG file.

        Args:
            image: Image the metadata belongs to
            metadata: Mapping or GenerationMetadata; every value is stored
                as ``str(value)``

        Returns:
            ``(image, png_info)`` where ``png_info`` is a
            ``PIL.PngImagePlugin.PngInfo`` holding one text entry per key
        """
        from PIL import PngImagePlugin

        png_info = PngImagePlugin.PngInfo()

        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)

        for key, value in metadata.items():
            png_info.add_text(key, str(value))

        return image, png_info

    @staticmethod
    def extract_metadata(image_path: str) -> Dict:
        """Extract text metadata from a saved image.

        Args:
            image_path: Path of the image file

        Returns:
            Copy of the image's ``text`` mapping, or an empty dict when the
            opened image has no ``text`` attribute
        """
        # The context manager closes the underlying file handle, which the
        # bare Image.open() call left open.
        with Image.open(image_path) as image:
            if hasattr(image, 'text'):
                return dict(image.text)
        return {}

    @staticmethod
    def save_metadata_json(
        metadata: Union[Dict, GenerationMetadata],
        filepath: str
    ):
        """Save metadata to a separate JSON file (UTF-8, two-space indent).

        Args:
            metadata: Mapping or GenerationMetadata (converted with asdict)
            filepath: Destination path; an existing file is overwritten
        """
        if isinstance(metadata, GenerationMetadata):
            metadata = asdict(metadata)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

    @staticmethod
    def load_metadata_json(filepath: str) -> Dict:
        """Load metadata from a UTF-8 JSON file and return the parsed content."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PerformanceMonitor:
    """Monitor and log generation performance.

    Attributes:
        generation_times: Elapsed seconds recorded by each stop() call
        memory_usage: List reserved for memory samples (not filled by this
            class itself)
        start_time: Timer reading taken by start(), or None when idle
    """

    def __init__(self):
        self.generation_times = []
        self.memory_usage = []
        self.start_time = None

    def start(self):
        """Start timing."""
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or inflated durations.
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing and return elapsed time.

        Returns:
            Seconds since the matching start(), also appended to
            ``generation_times``; 0.0 (nothing recorded) if start() was not
            called first.
        """
        if self.start_time is None:
            return 0.0
        elapsed = time.perf_counter() - self.start_time
        self.generation_times.append(elapsed)
        self.start_time = None
        return elapsed

    def get_gpu_memory(self) -> Dict:
        """Get current GPU memory usage.

        Returns:
            ``{"available": False}`` without CUDA; otherwise a dict with
            ``available`` (True) plus ``allocated``, ``reserved`` and
            ``max_allocated`` as reported by torch.cuda, divided by 1024**3.
        """
        if not torch.cuda.is_available():
            return {"available": False}

        gib = 1024 ** 3
        return {
            # Same key on both paths so callers can always test "available".
            "available": True,
            "allocated": torch.cuda.memory_allocated() / gib,
            "reserved": torch.cuda.memory_reserved() / gib,
            "max_allocated": torch.cuda.max_memory_allocated() / gib
        }

    def get_statistics(self) -> Dict:
        """Get performance statistics.

        Returns:
            ``{"no_data": True}`` when nothing has been recorded; otherwise
            count, total, average, min, max and standard deviation of the
            recorded times as native Python numbers.
        """
        if not self.generation_times:
            return {"no_data": True}

        times = self.generation_times
        return {
            "total_generations": len(times),
            "total_time": sum(times),
            "average_time": float(np.mean(times)),
            "min_time": min(times),
            "max_time": max(times),
            "std_time": float(np.std(times))
        }

    def reset(self):
        """Reset all statistics and cancel any running timer."""
        self.generation_times = []
        self.memory_usage = []
        self.start_time = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigManager:
    """Read, write, build and check JSON configuration dictionaries."""

    @staticmethod
    def load_config(filepath: str) -> Dict:
        """Parse the JSON file at ``filepath`` and return its content."""
        with open(filepath, 'r') as handle:
            loaded = json.load(handle)
        return loaded

    @staticmethod
    def save_config(config: Dict, filepath: str):
        """Write ``config`` to ``filepath`` as JSON with a two-space indent."""
        with open(filepath, 'w') as handle:
            json.dump(config, handle, indent=2)

    @staticmethod
    def create_default_config() -> Dict:
        """Return a freshly built default configuration dictionary."""
        generation_defaults = {
            "width": 512,
            "height": 512,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
        }
        optimization = {
            "attention_slicing": True,
            "vae_slicing": True,
            "xformers": True,
        }
        output = {
            "format": "png",
            "quality": 95,
            "save_metadata": True,
        }

        config = {}
        config["model_id"] = "OpenTrouter/Trouter-Imagine-1"
        config["device"] = "cuda"
        config["dtype"] = "float16"
        config["defaults"] = generation_defaults
        config["optimization"] = optimization
        config["output"] = output
        return config

    @staticmethod
    def validate_config(config: Dict) -> Tuple[bool, List[str]]:
        """Check a configuration for required keys and a known device.

        Returns:
            ``(is_valid, errors)``: one message per missing required key
            (model_id, device, defaults), followed by a message when
            ``device`` is present but not cuda, cpu or mps.
        """
        problems = [
            f"Missing required key: {name}"
            for name in ("model_id", "device", "defaults")
            if name not in config
        ]

        if "device" in config and config["device"] not in ("cuda", "cpu", "mps"):
            problems.append(f"Invalid device: {config['device']}")

        return not problems, problems
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchProcessor:
    """Helper for batch processing operations."""

    @staticmethod
    def load_prompts_from_file(filepath: str) -> List[str]:
        """Load prompts from a UTF-8 text file (one per line).

        Blank lines and comment lines are skipped. A line counts as a comment
        when its first non-whitespace character is ``#``; the test is made on
        the stripped text so indented comments are no longer returned as
        prompts.

        Returns:
            Prompts in file order with surrounding whitespace removed
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            prompts = [
                text for text in stripped_lines
                if text and not text.startswith('#')
            ]
        return prompts

    @staticmethod
    def load_prompts_from_json(filepath: str) -> List[Dict]:
        """Load prompts and configs from a JSON file.

        Accepts either a top-level list, returned as is, or an object with a
        ``"prompts"`` key, whose value is returned.

        Raises:
            ValueError: If the document has neither shape
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "prompts" in data:
            return data["prompts"]
        raise ValueError("Invalid JSON format")

    @staticmethod
    def save_batch_results(
        results: List[Tuple[Image.Image, Dict]],
        output_dir: str,
        prefix: str = "batch"
    ):
        """Save batch generation results.

        For result ``i`` this writes ``{prefix}_{i:04d}.png`` and
        ``{prefix}_{i:04d}_metadata.json`` into ``output_dir``, which is
        created (with parents) when missing.

        Args:
            results: ``(image, metadata)`` pairs
            output_dir: Target directory
            prefix: File name prefix
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for i, (image, metadata) in enumerate(results):
            stem = f"{prefix}_{i:04d}"

            image.save(output_path / f"{stem}.png")

            metadata_file = output_path / f"{stem}_metadata.json"
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

    @staticmethod
    def create_batch_report(
        results: List[Tuple[Image.Image, Dict]],
        output_file: str
    ):
        """Create a summary report of batch processing.

        The JSON report holds the image count, a timestamp from
        ``datetime.now()``, one entry per result (missing metadata keys
        default to "" for the prompt and 0 for numbers) and, for a non-empty
        batch, total/average/min/max generation time.

        Args:
            results: ``(image, metadata)`` pairs; only the metadata is read
            output_file: Path of the JSON report to write
        """
        report = {
            "total_images": len(results),
            "timestamp": datetime.now().isoformat(),
            "images": [
                {
                    "index": i,
                    "prompt": metadata.get("prompt", ""),
                    "generation_time": metadata.get("generation_time", 0),
                    "parameters": {
                        "width": metadata.get("width", 0),
                        "height": metadata.get("height", 0),
                        "steps": metadata.get("num_inference_steps", 0),
                        "guidance": metadata.get("guidance_scale", 0)
                    }
                }
                for i, (_, metadata) in enumerate(results)
            ]
        }

        times = [m.get("generation_time", 0) for _, m in results]
        if times:
            report["statistics"] = {
                "total_time": sum(times),
                # Native float keeps the report free of NumPy scalar types.
                "average_time": float(np.mean(times)),
                "min_time": min(times),
                "max_time": max(times)
            }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FileManager:
    """Utilities for file management."""

    @staticmethod
    def create_directory_structure(base_dir: str) -> Dict[str, Path]:
        """Create organized directory structure.

        Creates ``outputs``, ``metadata``, ``configs``, ``logs`` and ``temp``
        below ``base_dir`` (existing directories are left alone).

        Returns:
            Mapping from those names to the corresponding ``Path`` objects
        """
        base = Path(base_dir)

        dirs = {
            name: base / name
            for name in ("outputs", "metadata", "configs", "logs", "temp")
        }

        for dir_path in dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

        return dirs

    @staticmethod
    def generate_filename(
        prompt: str,
        timestamp: bool = True,
        max_length: int = 50
    ) -> str:
        """Generate a ``.png`` filename from a prompt.

        The prompt is lower-cased, characters other than word characters,
        whitespace and hyphens are removed, runs of whitespace/hyphens become
        a single underscore and the result is cut to ``max_length``
        characters. Leading/trailing underscores are trimmed, and a prompt
        that leaves nothing usable (empty or punctuation only) falls back to
        ``image`` instead of producing a bare ``.png`` name.

        Args:
            prompt: Text to derive the name from
            timestamp: Prefix the name with ``YYYYMMDD_HHMMSS_``
            max_length: Maximum length of the prompt-derived part

        Returns:
            The generated filename
        """
        clean = re.sub(r'[^\w\s-]', '', prompt.lower())
        clean = re.sub(r'[-\s]+', '_', clean)
        clean = clean[:max_length].strip('_') or "image"

        if timestamp:
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            return f"{ts}_{clean}.png"

        return f"{clean}.png"

    @staticmethod
    def get_file_hash(filepath: str) -> str:
        """Calculate the MD5 hash of a file.

        The file is read in 4096-byte chunks so it never has to fit in
        memory at once.

        Returns:
            Hexadecimal digest string
        """
        hash_md5 = hashlib.md5()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    @staticmethod
    def cleanup_temp_files(temp_dir: str, older_than_hours: int = 24):
        """Clean up temporary files older than the specified number of hours.

        Only regular files directly inside ``temp_dir`` are considered; a
        file is deleted when its modification time is older than the cutoff.
        Each deletion is logged. A missing directory is a no-op.
        """
        temp_path = Path(temp_dir)
        if not temp_path.exists():
            return

        cutoff_time = time.time() - (older_than_hours * 3600)

        for entry in temp_path.glob("*"):
            if entry.is_file() and entry.stat().st_mtime < cutoff_time:
                entry.unlink()
                logger.info(f"Deleted old temp file: {entry}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QualityAssessor:
    """Numeric quality measurements computed on the grayscale ('L') image."""

    @staticmethod
    def calculate_sharpness(image: Image.Image) -> float:
        """Return the variance of the image's Laplacian response.

        The grayscale pixels are convolved with a 3x3 Laplacian kernel using
        ``mode='valid'`` and the variance of the result is returned.
        """
        gray = np.array(image.convert('L'))
        kernel = np.array([
            [0, 1, 0],
            [1, -4, 1],
            [0, 1, 0],
        ])

        # scipy is imported lazily, only when sharpness is requested.
        from scipy import signal

        response = signal.convolve2d(gray, kernel, mode='valid')
        return float(np.var(response))

    @staticmethod
    def calculate_brightness(image: Image.Image) -> float:
        """Return the mean grayscale pixel value."""
        gray = np.array(image.convert('L'))
        mean_value = np.mean(gray)
        return float(mean_value)

    @staticmethod
    def calculate_contrast(image: Image.Image) -> float:
        """Return the standard deviation of the grayscale pixel values."""
        gray = np.array(image.convert('L'))
        spread = np.std(gray)
        return float(spread)

    @staticmethod
    def assess_quality(image: Image.Image) -> Dict:
        """Collect sharpness, brightness, contrast, resolution and aspect ratio.

        Returns:
            Dict with keys ``sharpness``, ``brightness``, ``contrast``,
            ``resolution`` (a ``"WIDTHxHEIGHT"`` string) and
            ``aspect_ratio`` (width divided by height).
        """
        report = {}
        report["sharpness"] = QualityAssessor.calculate_sharpness(image)
        report["brightness"] = QualityAssessor.calculate_brightness(image)
        report["contrast"] = QualityAssessor.calculate_contrast(image)
        report["resolution"] = f"{image.width}x{image.height}"
        report["aspect_ratio"] = image.width / image.height
        return report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seed_everything(seed: int):
    """Set all random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available the CUDA generators are seeded as well, cuDNN is
    switched to deterministic mode and its benchmark mode is turned off.

    Args:
        seed: Seed value applied to every generator
    """
    # Imported locally: the module's visible import block does not bring
    # ``random`` into scope, which made the original call raise NameError.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
def _parse_aspect_ratio(aspect_ratio) -> Tuple[float, float]:
    """Parse a ``"W:H"`` string into positive numbers, or return (1, 1) if unusable."""
    try:
        width_text, height_text = str(aspect_ratio).split(":")
        ratio_w, ratio_h = float(width_text), float(height_text)
    except ValueError:
        return 1, 1
    # Rejects zero, negative, NaN and infinite components in one comparison.
    infinity = float("inf")
    if not (0 < ratio_w < infinity and 0 < ratio_h < infinity):
        return 1, 1
    return ratio_w, ratio_h


def get_optimal_resolution(
    target_pixels: int,
    aspect_ratio: str = "1:1"
) -> Tuple[int, int]:
    """
    Calculate optimal resolution for target pixel count

    The height is derived from the pixel budget and the ratio, the width
    from that height, and both are then rounded down to a multiple of 8
    (never below 8, so a tiny budget cannot yield a zero-sized dimension).

    Args:
        target_pixels: Target total pixels (e.g., 512*512 = 262144)
        aspect_ratio: Desired aspect ratio (e.g., "16:9", "4:3", "1:1").
            Any "W:H" pair of positive numbers is accepted, e.g. "21:9";
            values that cannot be parsed fall back to "1:1".

    Returns:
        (width, height) tuple

    Raises:
        ValueError: If ``target_pixels`` is not positive
    """
    if target_pixels <= 0:
        raise ValueError("target_pixels must be positive")

    # Common ratios are looked up directly; anything else is parsed.
    ratios = {
        "1:1": (1, 1),
        "4:3": (4, 3),
        "3:4": (3, 4),
        "16:9": (16, 9),
        "9:16": (9, 16),
        "3:2": (3, 2),
        "2:3": (2, 3)
    }

    ratio = ratios.get(aspect_ratio)
    if ratio is None:
        ratio = _parse_aspect_ratio(aspect_ratio)
    ratio_w, ratio_h = ratio

    height = int(np.sqrt(target_pixels * ratio_h / ratio_w))
    width = int(height * ratio_w / ratio_h)

    width = max(8, (width // 8) * 8)
    height = max(8, (height // 8) * 8)

    return width, height
|
|
|
|
|
|
|
|
def estimate_generation_time(
    width: int,
    height: int,
    steps: int,
    device: str = "cuda",
    gpu_model: str = "RTX 3080"
) -> float:
    """
    Estimate generation time based on parameters

    The estimate is a per-model base time multiplied by the number of steps
    and by the image area relative to 512x512.

    Args:
        width: Image width in pixels
        height: Image height in pixels
        steps: Number of inference steps
        device: Accepted but not used by the calculation
        gpu_model: Key into the base-time table; unknown models use 0.10

    Returns:
        Estimated time in seconds
    """
    reference_area = 512 * 512

    per_step_by_gpu = {
        "RTX 4090": 0.04,
        "RTX 3090": 0.07,
        "RTX 3080": 0.10,
        "RTX 2080": 0.15,
        "M1 Max": 0.25,
    }
    per_step = per_step_by_gpu.get(gpu_model, 0.10)

    area_scale = (width * height) / reference_area

    return per_step * steps * area_scale
|
|
|
|
|
|
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    # Data classes
    'GenerationConfig', 'GenerationMetadata',
    # Helper classes
    'PromptEnhancer', 'ImageProcessor', 'MetadataManager',
    'PerformanceMonitor', 'ConfigManager', 'BatchProcessor',
    'FileManager', 'QualityAssessor',
    # Stand-alone functions
    'seed_everything', 'get_optimal_resolution', 'estimate_generation_time',
]