| """
|
| Byte Dream Utilities
|
| Helper functions for image processing, model management, and optimization
|
| """
|
|
|
| import torch
|
| import numpy as np
|
| from PIL import Image
|
| from pathlib import Path
|
| import hashlib
|
| import json
|
| from typing import Optional, Tuple, List
|
|
|
|
|
def load_image(image_path: str) -> Image.Image:
    """
    Load an image from disk and return it converted to RGB.

    Args:
        image_path: Path to the image file.

    Returns:
        A new PIL Image in ``RGB`` mode.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        IOError: If the file exists but cannot be opened or decoded. The
            original exception is attached as ``__cause__``.
    """
    path = Path(image_path)

    if not path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    try:
        # convert() produces an independent image, so the source file handle
        # can be released as soon as the with-block exits.
        with Image.open(path) as source:
            return source.convert('RGB')
    except Exception as e:
        # Keep the historical IOError type for callers, but chain the cause
        # so the real decoding error stays visible in tracebacks.
        raise IOError(f"Error loading image: {e}") from e
|
|
|
|
|
def save_image(
    image: Image.Image,
    output_path: str,
    format: Optional[str] = None,
    quality: int = 95,
):
    """
    Save an image to a file, creating parent directories as needed.

    Args:
        image: PIL Image to save.
        output_path: Output file path.
        format: Image format name (PNG, JPEG, ...), case-insensitive. ``JPG``
            is accepted as an alias for ``JPEG``. When ``None`` the format is
            derived from the file extension of ``output_path``.
        quality: JPEG quality (1-100); only passed on for JPEG output.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)

    if format is None:
        # Ask Pillow which writer owns this extension, so suffixes whose
        # format name differs from the extension (".tif" -> "TIFF",
        # ".jpg" -> "JPEG") resolve correctly. Unknown extensions fall back
        # to the bare upper-cased suffix.
        suffix_format = path.suffix.upper().replace('.', '')
        format = Image.registered_extensions().get(path.suffix.lower(), suffix_format)
    else:
        # Normalise explicit formats so "jpeg"/"jpg" also get the JPEG path.
        format = format.upper()

    if format == 'JPG':
        format = 'JPEG'

    if format == 'JPEG':
        image.save(path, format=format, quality=quality, optimize=True)
    else:
        image.save(path, format=format, optimize=True)

    print(f"Image saved to: {path}")
|
|
|
|
|
def resize_image(
    image: Image.Image,
    width: Optional[int] = None,
    height: Optional[int] = None,
    maintain_aspect: bool = True,
) -> Image.Image:
    """
    Resize an image to the requested dimensions using LANCZOS resampling.

    Args:
        image: Input image.
        width: Target width; ``None`` or 0 means "not specified".
        height: Target height; ``None`` or 0 means "not specified".
        maintain_aspect: When True the aspect ratio is kept. With both
            ``width`` and ``height`` given, the image is scaled to fit inside
            that box. When False, any unspecified side keeps its original
            length.

    Returns:
        The resized PIL Image, or ``image`` itself when neither dimension is
        specified.
    """
    orig_width, orig_height = image.size

    # Treat 0 like None everywhere (the non-aspect branch already did), so a
    # lone 0 no longer falls through to a None / int division.
    if not width and not height:
        return image

    if maintain_aspect:
        if width and height:
            # Fit inside the width x height box.
            ratio = min(width / orig_width, height / orig_height)
            new_width = max(1, round(orig_width * ratio))
            new_height = max(1, round(orig_height * ratio))
        elif width:
            ratio = width / orig_width
            new_width = width
            new_height = max(1, round(orig_height * ratio))
        else:
            ratio = height / orig_height
            new_width = max(1, round(orig_width * ratio))
            new_height = height
    else:
        new_width = width if width else orig_width
        new_height = height if height else orig_height

    # round() rather than int(): truncation can lose a pixel to float error
    # (e.g. 49 * (1 / 49) evaluates to 0.9999999999999999). The max(1, ...)
    # keeps extreme aspect ratios from producing a 0-pixel side.
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
|
|
|
|
def center_crop(image: Image.Image, width: int, height: int) -> Image.Image:
    """
    Cut a ``width`` x ``height`` window out of the middle of an image.

    Args:
        image: Input image.
        width: Crop width.
        height: Crop height.

    Returns:
        Cropped PIL Image.
    """
    full_width, full_height = image.size

    # Top-left corner of the window; floor division keeps it on whole pixels.
    x0 = (full_width - width) // 2
    y0 = (full_height - height) // 2

    box = (x0, y0, x0 + width, y0 + height)
    return image.crop(box)
|
|
|
|
|
def image_to_tensor(image: Image.Image) -> torch.Tensor:
    """
    Convert a PIL Image to a channel-first PyTorch tensor.

    Args:
        image: PIL Image. Multi-channel images (e.g. RGB) and single-channel
            images (e.g. mode ``L``) are both accepted.

    Returns:
        float32 tensor of shape (C, H, W). Pixel values are divided by 255
        and mapped to [-1, 1], which assumes 8-bit channels.
    """
    img_array = np.array(image).astype(np.float32)

    # Single-channel images come back as (H, W); add the channel axis so the
    # permute below works for them too.
    if img_array.ndim == 2:
        img_array = img_array[:, :, np.newaxis]

    # [0, 255] -> [0, 1]
    img_array = img_array / 255.0

    # [0, 1] -> [-1, 1]
    img_array = 2.0 * img_array - 1.0

    # (H, W, C) -> (C, H, W)
    tensor = torch.from_numpy(img_array).permute(2, 0, 1)

    return tensor
|
|
|
|
|
def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """
    Convert a PyTorch tensor to a PIL Image.

    Args:
        tensor: Tensor in range [-1, 1], shape (B, C, H, W) or (C, H, W).
            For a batched tensor only the first element is converted.
            Values outside [-1, 1] are clipped. The tensor may live on any
            device and may require grad.

    Returns:
        PIL Image; a single-channel tensor is expanded to three identical
        channels.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]

    # detach() first: numpy() refuses tensors that are part of an autograd
    # graph.
    tensor = tensor.detach().cpu()

    # bfloat16 has no NumPy equivalent, and float16 is too coarse for the
    # 0..255 scaling below, so widen both to float32.
    if tensor.dtype in (torch.float16, torch.bfloat16):
        tensor = tensor.float()

    # (C, H, W) -> (H, W, C)
    img_array = tensor.numpy().transpose(1, 2, 0)

    img_array = np.clip(img_array, -1, 1)

    # [-1, 1] -> [0, 255]
    img_array = ((img_array + 1.0) * 127.5).round().astype(np.uint8)

    # Replicate a lone channel so the result is a regular 3-channel image.
    if img_array.shape[2] == 1:
        img_array = np.repeat(img_array, 3, axis=2)

    return Image.fromarray(img_array)
|
|
|
|
|
def generate_prompt_hash(prompt: str) -> str:
    """
    Derive a short identifier from a prompt.

    Args:
        prompt: Text prompt.

    Returns:
        The first 8 hexadecimal characters of the prompt's MD5 digest.
    """
    encoded = prompt.encode()
    full_digest = hashlib.md5(encoded).hexdigest()
    return full_digest[:8]
|
|
|
|
|
def get_model_statistics(model: torch.nn.Module) -> dict:
    """
    Summarise the parameter counts and memory footprint of a model.

    Args:
        model: PyTorch model.

    Returns:
        Dictionary with ``total_parameters``, ``trainable_parameters``,
        ``non_trainable_parameters`` and ``model_size_mb`` (parameters plus
        buffers, in MiB, rounded to two decimals).
    """
    n_total = 0
    n_trainable = 0
    param_bytes = 0

    # One pass over the parameters gathers counts and byte size together.
    for p in model.parameters():
        count = p.numel()
        n_total += count
        if p.requires_grad:
            n_trainable += count
        param_bytes += count * p.element_size()

    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())

    megabytes = (param_bytes + buffer_bytes) / 1024 ** 2

    return {
        'total_parameters': n_total,
        'trainable_parameters': n_trainable,
        'non_trainable_parameters': n_total - n_trainable,
        'model_size_mb': round(megabytes, 2),
    }
|
|
|
|
|
def optimize_memory_usage(device: str = "cpu"):
    """
    Release unused memory before inference.

    Runs the Python garbage collector, empties the CUDA caching allocator
    when CUDA is available, and for ``device == "cpu"`` sets the
    ``KMP_DUPLICATE_LIB_OK`` environment variable to ``TRUE``.

    Args:
        device: Target device. Only the value ``"cpu"`` changes behaviour
            (it triggers the environment-variable tweak).
    """
    import gc
    import os

    # Collect first: tensors kept alive only by reference cycles must be
    # freed before empty_cache() can hand their blocks back to the driver.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if device == "cpu":
        # Best-effort tweak; a failure to update the environment must not
        # abort inference, but unrelated errors are no longer swallowed.
        # NOTE(review): presumably intended for the OpenMP runtime; confirm
        # it still has an effect when set after torch has been imported.
        try:
            os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
        except OSError:
            pass

    print("Memory optimization applied")
|
|
|
|
|
def set_seed(seed: int):
    """
    Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, plus all CUDA devices when CUDA is available).

    Args:
        seed: Random seed value.
    """
    import random

    # Python's own generator was previously left unseeded, so any code using
    # `random` stayed non-reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def validate_prompt(prompt: str) -> Tuple[bool, str]:
    """
    Check whether a prompt is acceptable.

    A prompt is rejected when it is empty or whitespace-only, when it is
    longer than 1000 characters, or when it contains a forbidden term
    (case-insensitive; the term list is currently empty).

    Args:
        prompt: Input prompt.

    Returns:
        Tuple of (is_valid, message).
    """
    if not (prompt and prompt.strip()):
        return False, "Prompt cannot be empty"

    if len(prompt) > 1000:
        return False, "Prompt too long (max 1000 characters)"

    forbidden_terms = []
    lowered = prompt.lower()
    # First forbidden term found in the prompt, if any.
    offending = next(
        (term for term in forbidden_terms if term.lower() in lowered),
        None,
    )
    if offending is not None:
        return False, f"Prompt contains forbidden term: {offending}"

    return True, "Valid prompt"
|
|
|
|
|
def create_image_grid(
    images: List[Image.Image],
    rows: int = None,
    cols: int = None,
) -> Image.Image:
    """
    Tile images into a single grid image, filled row by row.

    Args:
        images: List of PIL Images. The cell size is taken from the first
            image.
        rows: Number of rows; derived from ``cols`` when omitted.
        cols: Number of columns; derived from ``rows`` when omitted. With
            neither given, a near-square layout is chosen.

    Returns:
        Grid image (RGB, white background).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided")

    count = len(images)

    # Fill in whichever grid dimension the caller left out.
    if cols is None:
        if rows is None:
            cols = int(np.ceil(np.sqrt(count)))
        else:
            cols = int(np.ceil(count / rows))
    if rows is None:
        rows = int(np.ceil(count / cols))

    cell_width, cell_height = images[0].size

    canvas = Image.new('RGB', (cols * cell_width, rows * cell_height), color='white')

    for index, tile in enumerate(images):
        grid_row, grid_col = divmod(index, cols)
        canvas.paste(tile, (grid_col * cell_width, grid_row * cell_height))

    return canvas
|
|
|
|
|
def get_device_info() -> dict:
    """
    Collect basic information about the available compute devices.

    Returns:
        Dictionary with ``cuda_available``, ``device_count`` and
        ``cpu_cores``. When CUDA is available it also contains
        ``current_device``, ``device_name`` (name of the current device)
        and ``cuda_version``.
    """
    import os

    cuda_available = torch.cuda.is_available()

    info = {
        'cuda_available': cuda_available,
        'device_count': torch.cuda.device_count() if cuda_available else 0,
        'cpu_cores': os.cpu_count(),
    }

    if cuda_available:
        current = torch.cuda.current_device()
        info['current_device'] = current
        # Report the name of the device actually selected, not always GPU 0,
        # so it agrees with 'current_device' on multi-GPU machines.
        info['device_name'] = torch.cuda.get_device_name(current)
        info['cuda_version'] = torch.version.cuda

    return info
|
|
|
|
|
class ProgressTracker:
    """Track progress of long-running operations"""

    def __init__(self, total: int, description: str = ""):
        # total: number of work units expected; current: units done so far.
        self.total = total
        self.current = 0
        self.description = description

    def update(self, n: int = 1):
        """Advance the progress counter by ``n`` units."""
        self.current += n

    def get_progress(self) -> float:
        """Return progress as a percentage; 0.0 when ``total`` is not positive."""
        return (self.current / self.total) * 100 if self.total > 0 else 0.0

    def __str__(self):
        """Render a 30-character text progress bar with percentage and counts."""
        percent = self.get_progress()
        bar_length = 30
        if self.total > 0:
            filled_length = int(bar_length * self.current // self.total)
            # Clamp so over- or under-shooting never distorts the bar width.
            filled_length = max(0, min(bar_length, filled_length))
        else:
            # Mirrors the guard in get_progress(): an empty job renders an
            # empty bar instead of dividing by zero.
            filled_length = 0
        bar = '█' * filled_length + '-' * (bar_length - filled_length)
        return f"{self.description}: [{bar}] {percent:.1f}% ({self.current}/{self.total})"
|
|
|