Luke-Bergen commited on
Commit
56fcf2e
·
verified ·
1 Parent(s): b36e7c7

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +897 -0
utils.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Trouter-Imagine-1 Utilities and Helper Functions
4
+ Apache 2.0 License
5
+
6
+ Comprehensive utility module providing:
7
+ - Prompt enhancement and optimization
8
+ - Image post-processing
9
+ - Metadata management
10
+ - Performance monitoring
11
+ - Configuration management
12
+ - Quality assessment
13
+ - Batch processing helpers
14
+ - File management
15
+ - API wrappers
16
+ - Advanced preprocessing
17
+ """
18
+
19
+ import torch
20
+ from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont
21
+ import numpy as np
22
+ from typing import List, Dict, Tuple, Optional, Union
23
+ import json
24
+ import os
25
+ import hashlib
26
+ from pathlib import Path
27
+ from datetime import datetime
28
+ import re
29
+ import logging
30
+ from dataclasses import dataclass, asdict
31
+ import time
32
+ from collections import defaultdict
33
+
34
+ # Configure logging
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # ============================================================================
40
+ # DATA CLASSES FOR CONFIGURATION
41
+ # ============================================================================
42
+
43
+ @dataclass
44
+ class GenerationConfig:
45
+ """Configuration for image generation"""
46
+ prompt: str
47
+ negative_prompt: str = ""
48
+ width: int = 512
49
+ height: int = 512
50
+ num_inference_steps: int = 30
51
+ guidance_scale: float = 7.5
52
+ seed: Optional[int] = None
53
+ num_images: int = 1
54
+
55
+ def to_dict(self) -> Dict:
56
+ return asdict(self)
57
+
58
+ @classmethod
59
+ def from_dict(cls, data: Dict) -> 'GenerationConfig':
60
+ return cls(**data)
61
+
62
+ def validate(self) -> bool:
63
+ """Validate configuration parameters"""
64
+ if self.width % 8 != 0 or self.height % 8 != 0:
65
+ raise ValueError("Width and height must be multiples of 8")
66
+ if self.num_inference_steps < 1:
67
+ raise ValueError("num_inference_steps must be at least 1")
68
+ if self.guidance_scale < 0:
69
+ raise ValueError("guidance_scale must be positive")
70
+ return True
71
+
72
+
73
+ @dataclass
74
+ class GenerationMetadata:
75
+ """Metadata for generated images"""
76
+ prompt: str
77
+ negative_prompt: str
78
+ model_id: str
79
+ width: int
80
+ height: int
81
+ num_inference_steps: int
82
+ guidance_scale: float
83
+ seed: int
84
+ timestamp: str
85
+ generation_time: float
86
+ scheduler: str = "unknown"
87
+
88
+ def to_json(self) -> str:
89
+ return json.dumps(asdict(self), indent=2)
90
+
91
+ @classmethod
92
+ def from_json(cls, json_str: str) -> 'GenerationMetadata':
93
+ return cls(**json.loads(json_str))
94
+
95
+
96
+ # ============================================================================
97
+ # PROMPT ENHANCEMENT
98
+ # ============================================================================
99
+
100
+ class PromptEnhancer:
101
+ """Enhance and optimize prompts for better generation"""
102
+
103
+ QUALITY_BOOSTERS = [
104
+ "highly detailed",
105
+ "professional",
106
+ "4k",
107
+ "ultra detailed",
108
+ "sharp focus",
109
+ "intricate details"
110
+ ]
111
+
112
+ STYLE_KEYWORDS = {
113
+ "photo": ["photography", "realistic", "photorealistic", "sharp focus"],
114
+ "art": ["digital art", "concept art", "artistic", "detailed"],
115
+ "paint": ["oil painting", "painterly", "brushstrokes", "canvas"],
116
+ "anime": ["anime style", "manga", "cel shaded", "vibrant"],
117
+ "3d": ["3d render", "octane render", "unreal engine", "cgi"]
118
+ }
119
+
120
+ NEGATIVE_DEFAULTS = [
121
+ "blurry", "low quality", "distorted", "deformed",
122
+ "ugly", "bad anatomy", "watermark", "signature"
123
+ ]
124
+
125
+ @staticmethod
126
+ def enhance_prompt(
127
+ prompt: str,
128
+ style: Optional[str] = None,
129
+ add_quality: bool = True,
130
+ add_details: bool = True
131
+ ) -> str:
132
+ """
133
+ Enhance a prompt with quality boosters and style keywords
134
+
135
+ Args:
136
+ prompt: Base prompt
137
+ style: Style to apply (photo, art, paint, anime, 3d)
138
+ add_quality: Add quality boosters
139
+ add_details: Add detail-related keywords
140
+
141
+ Returns:
142
+ Enhanced prompt
143
+ """
144
+ enhanced = prompt.strip()
145
+
146
+ # Add style keywords
147
+ if style and style.lower() in PromptEnhancer.STYLE_KEYWORDS:
148
+ style_words = PromptEnhancer.STYLE_KEYWORDS[style.lower()]
149
+ enhanced += ", " + ", ".join(style_words[:2])
150
+
151
+ # Add quality boosters
152
+ if add_quality:
153
+ quality_words = PromptEnhancer.QUALITY_BOOSTERS[:3]
154
+ enhanced += ", " + ", ".join(quality_words)
155
+
156
+ return enhanced
157
+
158
+ @staticmethod
159
+ def build_negative_prompt(
160
+ base_negative: str = "",
161
+ include_defaults: bool = True,
162
+ subject_type: Optional[str] = None
163
+ ) -> str:
164
+ """
165
+ Build a comprehensive negative prompt
166
+
167
+ Args:
168
+ base_negative: User-provided negative prompt
169
+ include_defaults: Include default negative terms
170
+ subject_type: Type of subject (person, animal, landscape, etc.)
171
+
172
+ Returns:
173
+ Enhanced negative prompt
174
+ """
175
+ negatives = []
176
+
177
+ if base_negative:
178
+ negatives.append(base_negative)
179
+
180
+ if include_defaults:
181
+ negatives.extend(PromptEnhancer.NEGATIVE_DEFAULTS)
182
+
183
+ # Subject-specific negatives
184
+ subject_negatives = {
185
+ "person": ["extra limbs", "extra fingers", "fused fingers", "bad hands"],
186
+ "animal": ["extra legs", "incorrect anatomy", "fused limbs"],
187
+ "face": ["asymmetric eyes", "crossed eyes", "bad teeth"],
188
+ "landscape": ["oversaturated", "underexposed", "poor composition"]
189
+ }
190
+
191
+ if subject_type and subject_type.lower() in subject_negatives:
192
+ negatives.extend(subject_negatives[subject_type.lower()])
193
+
194
+ return ", ".join(negatives)
195
+
196
+ @staticmethod
197
+ def extract_keywords(prompt: str) -> List[str]:
198
+ """Extract important keywords from a prompt"""
199
+ # Remove common words
200
+ stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'with', 'by', 'for'}
201
+ words = prompt.lower().split()
202
+ keywords = [w.strip('.,!?;:') for w in words if w not in stop_words]
203
+ return keywords
204
+
205
+ @staticmethod
206
+ def validate_prompt(prompt: str) -> Tuple[bool, List[str]]:
207
+ """
208
+ Validate a prompt and return warnings
209
+
210
+ Returns:
211
+ (is_valid, list_of_warnings)
212
+ """
213
+ warnings = []
214
+
215
+ if len(prompt.strip()) < 3:
216
+ warnings.append("Prompt is very short, consider adding more detail")
217
+
218
+ if len(prompt) > 500:
219
+ warnings.append("Prompt is very long, may be truncated")
220
+
221
+ # Check for common issues
222
+ if "high quality" in prompt.lower() and "low quality" in prompt.lower():
223
+ warnings.append("Contradictory quality terms detected")
224
+
225
+ # Check for excessive punctuation
226
+ if prompt.count(',') > 20:
227
+ warnings.append("Too many commas, consider simplifying")
228
+
229
+ return len(warnings) == 0, warnings
230
+
231
+
232
+ # ============================================================================
233
+ # IMAGE POST-PROCESSING
234
+ # ============================================================================
235
+
236
+ class ImageProcessor:
237
+ """Post-processing utilities for generated images"""
238
+
239
+ @staticmethod
240
+ def enhance_image(
241
+ image: Image.Image,
242
+ brightness: float = 1.0,
243
+ contrast: float = 1.0,
244
+ saturation: float = 1.0,
245
+ sharpness: float = 1.0
246
+ ) -> Image.Image:
247
+ """
248
+ Enhance image with various adjustments
249
+
250
+ Args:
251
+ image: Input PIL Image
252
+ brightness: Brightness factor (1.0 = no change)
253
+ contrast: Contrast factor
254
+ saturation: Color saturation factor
255
+ sharpness: Sharpness factor
256
+
257
+ Returns:
258
+ Enhanced image
259
+ """
260
+ enhanced = image
261
+
262
+ if brightness != 1.0:
263
+ enhancer = ImageEnhance.Brightness(enhanced)
264
+ enhanced = enhancer.enhance(brightness)
265
+
266
+ if contrast != 1.0:
267
+ enhancer = ImageEnhance.Contrast(enhanced)
268
+ enhanced = enhancer.enhance(contrast)
269
+
270
+ if saturation != 1.0:
271
+ enhancer = ImageEnhance.Color(enhanced)
272
+ enhanced = enhancer.enhance(saturation)
273
+
274
+ if sharpness != 1.0:
275
+ enhancer = ImageEnhance.Sharpness(enhanced)
276
+ enhanced = enhancer.enhance(sharpness)
277
+
278
+ return enhanced
279
+
280
+ @staticmethod
281
+ def apply_filter(
282
+ image: Image.Image,
283
+ filter_type: str = "none"
284
+ ) -> Image.Image:
285
+ """
286
+ Apply various filters to image
287
+
288
+ Args:
289
+ image: Input image
290
+ filter_type: Type of filter (blur, sharpen, edge_enhance, smooth, detail)
291
+
292
+ Returns:
293
+ Filtered image
294
+ """
295
+ filters = {
296
+ "blur": ImageFilter.BLUR,
297
+ "sharpen": ImageFilter.SHARPEN,
298
+ "edge_enhance": ImageFilter.EDGE_ENHANCE,
299
+ "edge_enhance_more": ImageFilter.EDGE_ENHANCE_MORE,
300
+ "smooth": ImageFilter.SMOOTH,
301
+ "smooth_more": ImageFilter.SMOOTH_MORE,
302
+ "detail": ImageFilter.DETAIL
303
+ }
304
+
305
+ if filter_type.lower() in filters:
306
+ return image.filter(filters[filter_type.lower()])
307
+
308
+ return image
309
+
310
+ @staticmethod
311
+ def upscale_simple(
312
+ image: Image.Image,
313
+ scale: int = 2,
314
+ method: str = "lanczos"
315
+ ) -> Image.Image:
316
+ """Simple upscaling using PIL"""
317
+ methods = {
318
+ "lanczos": Image.LANCZOS,
319
+ "bicubic": Image.BICUBIC,
320
+ "bilinear": Image.BILINEAR,
321
+ "nearest": Image.NEAREST
322
+ }
323
+
324
+ resample = methods.get(method.lower(), Image.LANCZOS)
325
+ new_size = (image.width * scale, image.height * scale)
326
+ return image.resize(new_size, resample=resample)
327
+
328
+ @staticmethod
329
+ def add_watermark(
330
+ image: Image.Image,
331
+ text: str,
332
+ position: str = "bottom-right",
333
+ opacity: int = 128
334
+ ) -> Image.Image:
335
+ """Add text watermark to image"""
336
+ watermark = image.copy()
337
+ draw = ImageDraw.Draw(watermark, 'RGBA')
338
+
339
+ # Try to load a font
340
+ try:
341
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
342
+ except:
343
+ font = ImageFont.load_default()
344
+
345
+ # Calculate position
346
+ bbox = draw.textbbox((0, 0), text, font=font)
347
+ text_width = bbox[2] - bbox[0]
348
+ text_height = bbox[3] - bbox[1]
349
+
350
+ positions = {
351
+ "top-left": (10, 10),
352
+ "top-right": (image.width - text_width - 10, 10),
353
+ "bottom-left": (10, image.height - text_height - 10),
354
+ "bottom-right": (image.width - text_width - 10, image.height - text_height - 10),
355
+ "center": ((image.width - text_width) // 2, (image.height - text_height) // 2)
356
+ }
357
+
358
+ pos = positions.get(position, positions["bottom-right"])
359
+
360
+ # Draw with opacity
361
+ draw.text(pos, text, fill=(255, 255, 255, opacity), font=font)
362
+
363
+ return watermark
364
+
365
+ @staticmethod
366
+ def create_comparison(
367
+ images: List[Image.Image],
368
+ labels: Optional[List[str]] = None,
369
+ padding: int = 10
370
+ ) -> Image.Image:
371
+ """Create side-by-side comparison of images"""
372
+ if not images:
373
+ raise ValueError("No images provided")
374
+
375
+ # Ensure all images have same height
376
+ max_height = max(img.height for img in images)
377
+ resized_images = []
378
+
379
+ for img in images:
380
+ if img.height != max_height:
381
+ ratio = max_height / img.height
382
+ new_width = int(img.width * ratio)
383
+ img = img.resize((new_width, max_height), Image.LANCZOS)
384
+ resized_images.append(img)
385
+
386
+ # Calculate total width
387
+ total_width = sum(img.width for img in resized_images) + padding * (len(resized_images) - 1)
388
+
389
+ # Create comparison image
390
+ comparison = Image.new('RGB', (total_width, max_height), color='white')
391
+
392
+ x_offset = 0
393
+ for i, img in enumerate(resized_images):
394
+ comparison.paste(img, (x_offset, 0))
395
+
396
+ # Add label if provided
397
+ if labels and i < len(labels):
398
+ draw = ImageDraw.Draw(comparison)
399
+ draw.text((x_offset + 10, 10), labels[i], fill='white')
400
+
401
+ x_offset += img.width + padding
402
+
403
+ return comparison
404
+
405
+ @staticmethod
406
+ def get_image_stats(image: Image.Image) -> Dict:
407
+ """Get statistical information about an image"""
408
+ img_array = np.array(image)
409
+
410
+ stats = {
411
+ "size": image.size,
412
+ "mode": image.mode,
413
+ "mean_brightness": np.mean(img_array),
414
+ "std_brightness": np.std(img_array),
415
+ "min_value": np.min(img_array),
416
+ "max_value": np.max(img_array)
417
+ }
418
+
419
+ if len(img_array.shape) == 3:
420
+ stats["mean_per_channel"] = np.mean(img_array, axis=(0, 1)).tolist()
421
+
422
+ return stats
423
+
424
+
425
+ # ============================================================================
426
+ # METADATA MANAGEMENT
427
+ # ============================================================================
428
+
429
+ class MetadataManager:
430
+ """Manage image metadata"""
431
+
432
+ @staticmethod
433
+ def embed_metadata(
434
+ image: Image.Image,
435
+ metadata: Union[Dict, GenerationMetadata]
436
+ ) -> Image.Image:
437
+ """Embed metadata into image"""
438
+ from PIL import PngImagePlugin
439
+
440
+ png_info = PngImagePlugin.PngInfo()
441
+
442
+ if isinstance(metadata, GenerationMetadata):
443
+ metadata = asdict(metadata)
444
+
445
+ for key, value in metadata.items():
446
+ png_info.add_text(key, str(value))
447
+
448
+ return image, png_info
449
+
450
+ @staticmethod
451
+ def extract_metadata(image_path: str) -> Dict:
452
+ """Extract metadata from saved image"""
453
+ image = Image.open(image_path)
454
+ metadata = {}
455
+
456
+ if hasattr(image, 'text'):
457
+ metadata = dict(image.text)
458
+
459
+ return metadata
460
+
461
+ @staticmethod
462
+ def save_metadata_json(
463
+ metadata: Union[Dict, GenerationMetadata],
464
+ filepath: str
465
+ ):
466
+ """Save metadata to separate JSON file"""
467
+ if isinstance(metadata, GenerationMetadata):
468
+ metadata = asdict(metadata)
469
+
470
+ with open(filepath, 'w') as f:
471
+ json.dump(metadata, f, indent=2)
472
+
473
+ @staticmethod
474
+ def load_metadata_json(filepath: str) -> Dict:
475
+ """Load metadata from JSON file"""
476
+ with open(filepath, 'r') as f:
477
+ return json.load(f)
478
+
479
+
480
+ # ============================================================================
481
+ # PERFORMANCE MONITORING
482
+ # ============================================================================
483
+
484
+ class PerformanceMonitor:
485
+ """Monitor and log generation performance"""
486
+
487
+ def __init__(self):
488
+ self.generation_times = []
489
+ self.memory_usage = []
490
+ self.start_time = None
491
+
492
+ def start(self):
493
+ """Start timing"""
494
+ self.start_time = time.time()
495
+
496
+ def stop(self) -> float:
497
+ """Stop timing and return elapsed time"""
498
+ if self.start_time is None:
499
+ return 0.0
500
+ elapsed = time.time() - self.start_time
501
+ self.generation_times.append(elapsed)
502
+ self.start_time = None
503
+ return elapsed
504
+
505
+ def get_gpu_memory(self) -> Dict:
506
+ """Get current GPU memory usage"""
507
+ if not torch.cuda.is_available():
508
+ return {"available": False}
509
+
510
+ return {
511
+ "allocated": torch.cuda.memory_allocated() / 1024**3, # GB
512
+ "reserved": torch.cuda.memory_reserved() / 1024**3,
513
+ "max_allocated": torch.cuda.max_memory_allocated() / 1024**3
514
+ }
515
+
516
+ def get_statistics(self) -> Dict:
517
+ """Get performance statistics"""
518
+ if not self.generation_times:
519
+ return {"no_data": True}
520
+
521
+ return {
522
+ "total_generations": len(self.generation_times),
523
+ "total_time": sum(self.generation_times),
524
+ "average_time": np.mean(self.generation_times),
525
+ "min_time": min(self.generation_times),
526
+ "max_time": max(self.generation_times),
527
+ "std_time": np.std(self.generation_times)
528
+ }
529
+
530
+ def reset(self):
531
+ """Reset all statistics"""
532
+ self.generation_times = []
533
+ self.memory_usage = []
534
+ self.start_time = None
535
+
536
+
537
+ # ============================================================================
538
+ # CONFIGURATION MANAGEMENT
539
+ # ============================================================================
540
+
541
+ class ConfigManager:
542
+ """Manage configuration files"""
543
+
544
+ @staticmethod
545
+ def load_config(filepath: str) -> Dict:
546
+ """Load configuration from JSON file"""
547
+ with open(filepath, 'r') as f:
548
+ return json.load(f)
549
+
550
+ @staticmethod
551
+ def save_config(config: Dict, filepath: str):
552
+ """Save configuration to JSON file"""
553
+ with open(filepath, 'w') as f:
554
+ json.dump(config, f, indent=2)
555
+
556
+ @staticmethod
557
+ def create_default_config() -> Dict:
558
+ """Create default configuration"""
559
+ return {
560
+ "model_id": "OpenTrouter/Trouter-Imagine-1",
561
+ "device": "cuda",
562
+ "dtype": "float16",
563
+ "defaults": {
564
+ "width": 512,
565
+ "height": 512,
566
+ "num_inference_steps": 30,
567
+ "guidance_scale": 7.5
568
+ },
569
+ "optimization": {
570
+ "attention_slicing": True,
571
+ "vae_slicing": True,
572
+ "xformers": True
573
+ },
574
+ "output": {
575
+ "format": "png",
576
+ "quality": 95,
577
+ "save_metadata": True
578
+ }
579
+ }
580
+
581
+ @staticmethod
582
+ def validate_config(config: Dict) -> Tuple[bool, List[str]]:
583
+ """Validate configuration"""
584
+ errors = []
585
+
586
+ required_keys = ["model_id", "device", "defaults"]
587
+ for key in required_keys:
588
+ if key not in config:
589
+ errors.append(f"Missing required key: {key}")
590
+
591
+ if "device" in config:
592
+ valid_devices = ["cuda", "cpu", "mps"]
593
+ if config["device"] not in valid_devices:
594
+ errors.append(f"Invalid device: {config['device']}")
595
+
596
+ return len(errors) == 0, errors
597
+
598
+
599
+ # ============================================================================
600
+ # BATCH PROCESSING HELPERS
601
+ # ============================================================================
602
+
603
+ class BatchProcessor:
604
+ """Helper for batch processing operations"""
605
+
606
+ @staticmethod
607
+ def load_prompts_from_file(filepath: str) -> List[str]:
608
+ """Load prompts from text file (one per line)"""
609
+ with open(filepath, 'r', encoding='utf-8') as f:
610
+ prompts = [line.strip() for line in f if line.strip() and not line.startswith('#')]
611
+ return prompts
612
+
613
+ @staticmethod
614
+ def load_prompts_from_json(filepath: str) -> List[Dict]:
615
+ """Load prompts and configs from JSON file"""
616
+ with open(filepath, 'r') as f:
617
+ data = json.load(f)
618
+
619
+ if isinstance(data, list):
620
+ return data
621
+ elif isinstance(data, dict) and "prompts" in data:
622
+ return data["prompts"]
623
+ else:
624
+ raise ValueError("Invalid JSON format")
625
+
626
+ @staticmethod
627
+ def save_batch_results(
628
+ results: List[Tuple[Image.Image, Dict]],
629
+ output_dir: str,
630
+ prefix: str = "batch"
631
+ ):
632
+ """Save batch generation results"""
633
+ output_path = Path(output_dir)
634
+ output_path.mkdir(parents=True, exist_ok=True)
635
+
636
+ for i, (image, metadata) in enumerate(results):
637
+ # Save image
638
+ image_file = output_path / f"{prefix}_{i:04d}.png"
639
+ image.save(image_file)
640
+
641
+ # Save metadata
642
+ metadata_file = output_path / f"{prefix}_{i:04d}_metadata.json"
643
+ with open(metadata_file, 'w') as f:
644
+ json.dump(metadata, f, indent=2)
645
+
646
+ @staticmethod
647
+ def create_batch_report(
648
+ results: List[Tuple[Image.Image, Dict]],
649
+ output_file: str
650
+ ):
651
+ """Create a summary report of batch processing"""
652
+ report = {
653
+ "total_images": len(results),
654
+ "timestamp": datetime.now().isoformat(),
655
+ "images": []
656
+ }
657
+
658
+ for i, (_, metadata) in enumerate(results):
659
+ report["images"].append({
660
+ "index": i,
661
+ "prompt": metadata.get("prompt", ""),
662
+ "generation_time": metadata.get("generation_time", 0),
663
+ "parameters": {
664
+ "width": metadata.get("width", 0),
665
+ "height": metadata.get("height", 0),
666
+ "steps": metadata.get("num_inference_steps", 0),
667
+ "guidance": metadata.get("guidance_scale", 0)
668
+ }
669
+ })
670
+
671
+ # Calculate statistics
672
+ times = [m.get("generation_time", 0) for _, m in results]
673
+ if times:
674
+ report["statistics"] = {
675
+ "total_time": sum(times),
676
+ "average_time": np.mean(times),
677
+ "min_time": min(times),
678
+ "max_time": max(times)
679
+ }
680
+
681
+ with open(output_file, 'w') as f:
682
+ json.dump(report, f, indent=2)
683
+
684
+
685
+ # ============================================================================
686
+ # FILE MANAGEMENT
687
+ # ============================================================================
688
+
689
+ class FileManager:
690
+ """Utilities for file management"""
691
+
692
+ @staticmethod
693
+ def create_directory_structure(base_dir: str) -> Dict[str, Path]:
694
+ """Create organized directory structure"""
695
+ base = Path(base_dir)
696
+
697
+ dirs = {
698
+ "outputs": base / "outputs",
699
+ "metadata": base / "metadata",
700
+ "configs": base / "configs",
701
+ "logs": base / "logs",
702
+ "temp": base / "temp"
703
+ }
704
+
705
+ for dir_path in dirs.values():
706
+ dir_path.mkdir(parents=True, exist_ok=True)
707
+
708
+ return dirs
709
+
710
+ @staticmethod
711
+ def generate_filename(
712
+ prompt: str,
713
+ timestamp: bool = True,
714
+ max_length: int = 50
715
+ ) -> str:
716
+ """Generate filename from prompt"""
717
+ # Clean prompt
718
+ clean = re.sub(r'[^\w\s-]', '', prompt.lower())
719
+ clean = re.sub(r'[-\s]+', '_', clean)
720
+ clean = clean[:max_length]
721
+
722
+ if timestamp:
723
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
724
+ return f"{ts}_{clean}.png"
725
+
726
+ return f"{clean}.png"
727
+
728
+ @staticmethod
729
+ def get_file_hash(filepath: str) -> str:
730
+ """Calculate MD5 hash of file"""
731
+ hash_md5 = hashlib.md5()
732
+ with open(filepath, "rb") as f:
733
+ for chunk in iter(lambda: f.read(4096), b""):
734
+ hash_md5.update(chunk)
735
+ return hash_md5.hexdigest()
736
+
737
+ @staticmethod
738
+ def cleanup_temp_files(temp_dir: str, older_than_hours: int = 24):
739
+ """Clean up temporary files older than specified hours"""
740
+ temp_path = Path(temp_dir)
741
+ if not temp_path.exists():
742
+ return
743
+
744
+ cutoff_time = time.time() - (older_than_hours * 3600)
745
+
746
+ for file in temp_path.glob("*"):
747
+ if file.is_file() and file.stat().st_mtime < cutoff_time:
748
+ file.unlink()
749
+ logger.info(f"Deleted old temp file: {file}")
750
+
751
+
752
+ # ============================================================================
753
+ # QUALITY ASSESSMENT
754
+ # ============================================================================
755
+
756
+ class QualityAssessor:
757
+ """Assess image quality"""
758
+
759
+ @staticmethod
760
+ def calculate_sharpness(image: Image.Image) -> float:
761
+ """Calculate image sharpness using Laplacian variance"""
762
+ img_array = np.array(image.convert('L'))
763
+ laplacian = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
764
+
765
+ # Convolve
766
+ from scipy import signal
767
+ filtered = signal.convolve2d(img_array, laplacian, mode='valid')
768
+ variance = np.var(filtered)
769
+
770
+ return float(variance)
771
+
772
+ @staticmethod
773
+ def calculate_brightness(image: Image.Image) -> float:
774
+ """Calculate average brightness"""
775
+ img_array = np.array(image.convert('L'))
776
+ return float(np.mean(img_array))
777
+
778
+ @staticmethod
779
+ def calculate_contrast(image: Image.Image) -> float:
780
+ """Calculate image contrast"""
781
+ img_array = np.array(image.convert('L'))
782
+ return float(np.std(img_array))
783
+
784
+ @staticmethod
785
+ def assess_quality(image: Image.Image) -> Dict:
786
+ """Comprehensive quality assessment"""
787
+ return {
788
+ "sharpness": QualityAssessor.calculate_sharpness(image),
789
+ "brightness": QualityAssessor.calculate_brightness(image),
790
+ "contrast": QualityAssessor.calculate_contrast(image),
791
+ "resolution": f"{image.width}x{image.height}",
792
+ "aspect_ratio": image.width / image.height
793
+ }
794
+
795
+
796
+ # ============================================================================
797
+ # UTILITY FUNCTIONS
798
+ # ============================================================================
799
+
800
+ def seed_everything(seed: int):
801
+ """Set all random seeds for reproducibility"""
802
+ random.seed(seed)
803
+ np.random.seed(seed)
804
+ torch.manual_seed(seed)
805
+ if torch.cuda.is_available():
806
+ torch.cuda.manual_seed(seed)
807
+ torch.cuda.manual_seed_all(seed)
808
+ torch.backends.cudnn.deterministic = True
809
+ torch.backends.cudnn.benchmark = False
810
+
811
+
812
+ def get_optimal_resolution(
813
+ target_pixels: int,
814
+ aspect_ratio: str = "1:1"
815
+ ) -> Tuple[int, int]:
816
+ """
817
+ Calculate optimal resolution for target pixel count
818
+
819
+ Args:
820
+ target_pixels: Target total pixels (e.g., 512*512 = 262144)
821
+ aspect_ratio: Desired aspect ratio (e.g., "16:9", "4:3", "1:1")
822
+
823
+ Returns:
824
+ (width, height) tuple
825
+ """
826
+ ratios = {
827
+ "1:1": (1, 1),
828
+ "4:3": (4, 3),
829
+ "3:4": (3, 4),
830
+ "16:9": (16, 9),
831
+ "9:16": (9, 16),
832
+ "3:2": (3, 2),
833
+ "2:3": (2, 3)
834
+ }
835
+
836
+ ratio_w, ratio_h = ratios.get(aspect_ratio, (1, 1))
837
+
838
+ # Calculate dimensions
839
+ height = int(np.sqrt(target_pixels * ratio_h / ratio_w))
840
+ width = int(height * ratio_w / ratio_h)
841
+
842
+ # Round to nearest multiple of 8
843
+ width = (width // 8) * 8
844
+ height = (height // 8) * 8
845
+
846
+ return width, height
847
+
848
+
849
+ def estimate_generation_time(
850
+ width: int,
851
+ height: int,
852
+ steps: int,
853
+ device: str = "cuda",
854
+ gpu_model: str = "RTX 3080"
855
+ ) -> float:
856
+ """
857
+ Estimate generation time based on parameters
858
+
859
+ Returns:
860
+ Estimated time in seconds
861
+ """
862
+ # Base time per step (seconds) for different GPUs at 512x512
863
+ base_times = {
864
+ "RTX 4090": 0.04,
865
+ "RTX 3090": 0.07,
866
+ "RTX 3080": 0.10,
867
+ "RTX 2080": 0.15,
868
+ "M1 Max": 0.25
869
+ }
870
+
871
+ base_time = base_times.get(gpu_model, 0.10)
872
+
873
+ # Scale by resolution
874
+ pixel_factor = (width * height) / (512 * 512)
875
+
876
+ # Estimate
877
+ estimated = base_time * steps * pixel_factor
878
+
879
+ return estimated
880
+
881
+
882
+ # Export main classes and functions
883
+ __all__ = [
884
+ 'GenerationConfig',
885
+ 'GenerationMetadata',
886
+ 'PromptEnhancer',
887
+ 'ImageProcessor',
888
+ 'MetadataManager',
889
+ 'PerformanceMonitor',
890
+ 'ConfigManager',
891
+ 'BatchProcessor',
892
+ 'FileManager',
893
+ 'QualityAssessor',
894
+ 'seed_everything',
895
+ 'get_optimal_resolution',
896
+ 'estimate_generation_time'
897
+ ]