import hashlib
import os
import pickle

import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def get_image_coordinates(H, W):
    """Generate a normalized coordinate grid for an image.

    Args:
        H: Image height.
        W: Image width.

    Returns:
        coords: Tensor of shape (H*W, 2) with (x, y) coordinates in [-1, 1].
    """
    x = torch.linspace(-1, 1, W)
    y = torch.linspace(-1, 1, H)
    # 'ij' indexing gives row-major (H, W) grids so the flattened order
    # matches the pixel order produced by image_to_tensor.
    Y, X = torch.meshgrid(y, x, indexing='ij')
    # Stack as (x, y) pairs and flatten to (H*W, 2).
    coords = torch.stack([X, Y], dim=-1).reshape(-1, 2)
    return coords


def image_to_tensor(image):
    """Convert a PIL Image to a flat, normalized pixel tensor.

    Args:
        image: PIL Image (any mode; converted to RGB if needed).

    Returns:
        Tensor of shape (H*W, 3) with values in [0, 1].
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # ToTensor yields (C, H, W) in [0, 1]; rearrange to (H, W, C) then flatten.
    img_tensor = transforms.ToTensor()(image)
    img_tensor = img_tensor.permute(1, 2, 0)
    img_tensor = img_tensor.reshape(-1, 3)
    return img_tensor


def tensor_to_image(tensor, H, W):
    """Convert a flat pixel tensor back to a PIL Image.

    Args:
        tensor: Tensor of shape (H*W, 3) with values in [0, 1].
        H: Image height.
        W: Image width.

    Returns:
        PIL Image in RGB mode.
    """
    img = tensor.reshape(H, W, 3)
    # Clamp before quantizing so out-of-range predictions don't wrap.
    img = torch.clamp(img, 0, 1)
    img = (img.cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(img)


def downsample_image(image, scale_factor):
    """Downsample an image by an integer scale factor.

    Args:
        image: PIL Image.
        scale_factor: Downsampling factor (e.g. 2 for half size).

    Returns:
        Downsampled PIL Image (dimensions use floor division).
    """
    W, H = image.size
    new_W = W // scale_factor
    new_H = H // scale_factor
    # Image.BICUBIC is deprecated in Pillow >= 9.1; prefer the Resampling
    # enum when available, falling back for older Pillow versions.
    resample = getattr(getattr(Image, "Resampling", Image), "BICUBIC")
    return image.resize((new_W, new_H), resample)


def train_siren(model, coords, pixels, num_steps=2000, learning_rate=1e-4,
                device='cpu'):
    """Train a SIREN model to fit an image via MSE regression.

    Args:
        model: SIREN model mapping coordinates to RGB values.
        coords: Coordinate tensor of shape (H*W, 2).
        pixels: Target pixel tensor of shape (H*W, 3).
        num_steps: Number of full-batch training steps.
        learning_rate: Adam learning rate.
        device: Device to train on ('cpu' or 'cuda').

    Returns:
        Tuple of (trained model, list of per-step loss values).
    """
    model = model.to(device)
    coords = coords.to(device)
    pixels = pixels.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    losses = []

    for step in range(num_steps):
        # Full-batch forward pass over every pixel coordinate.
        pred_pixels = model(coords)
        loss = torch.nn.functional.mse_loss(pred_pixels, pixels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        if (step + 1) % 200 == 0:
            print(f"Step {step + 1}/{num_steps}, Loss: {loss.item():.6f}")

    return model, losses


def compute_psnr(img1, img2):
    """Compute Peak Signal-to-Noise Ratio between two images.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1].
        img2: Second image tensor (H*W, 3) in [0, 1].

    Returns:
        PSNR value in dB (inf for identical images).
    """
    mse = torch.nn.functional.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    # Peak signal is 1.0 since inputs are normalized to [0, 1].
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()


def compute_mae(img1, img2):
    """Compute Mean Absolute Error between two images.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1].
        img2: Second image tensor (H*W, 3) in [0, 1].

    Returns:
        MAE value as a float.
    """
    mae = torch.nn.functional.l1_loss(img1, img2)
    return mae.item()


def compute_ssim_simple(img1, img2, window_size=11):
    """Compute a simplified, global SSIM between two images.

    NOTE: this uses whole-image statistics rather than the standard local
    sliding window, so ``window_size`` is currently unused. It is kept for
    interface compatibility with windowed SSIM implementations.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1].
        img2: Second image tensor (H*W, 3) in [0, 1].
        window_size: Unused; see note above.

    Returns:
        SSIM value (typically in [-1, 1]; 1 for identical images).
    """
    # Standard SSIM stabilization constants for dynamic range L = 1.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    mu1 = img1.mean()
    mu2 = img2.mean()
    sigma1_sq = ((img1 - mu1) ** 2).mean()
    sigma2_sq = ((img2 - mu2) ** 2).mean()
    sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()

    ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \
           ((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim.item()


def get_model_cache_path(image_path, scale_factor, training_steps,
                         hidden_features, hidden_layers):
    """Generate a cache file path for a trained model.

    Args:
        image_path: Path to the source image.
        scale_factor: Upscaling factor.
        training_steps: Number of training steps.
        hidden_features: Network width.
        hidden_layers: Network depth.

    Returns:
        Cache file path under the ``model_cache`` directory.
    """
    cache_dir = "model_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Use os.path helpers so Windows separators and dotted filenames
    # (e.g. "photo.v2.png") are handled correctly; splitext only strips
    # the final extension.
    image_name = os.path.splitext(os.path.basename(image_path))[0]

    filename = (f"{training_steps}steps_{scale_factor}x_{image_name}"
                f"_h{hidden_features}_l{hidden_layers}.pkl")
    return os.path.join(cache_dir, filename)


def save_model(model, cache_path):
    """Save a model's state dict to the cache.

    Args:
        model: SIREN model.
        cache_path: Path to save the model to.
    """
    # NOTE(review): pickle is kept for compatibility with existing caches,
    # but torch.save/torch.load is the idiomatic choice for state dicts.
    with open(cache_path, 'wb') as f:
        pickle.dump(model.state_dict(), f)
    print(f"Model saved to cache: {cache_path}")


def load_model(model, cache_path):
    """Load a model's state dict from the cache, if present.

    Args:
        model: SIREN model (architecture must match the cached weights).
        cache_path: Path to the cached model.

    Returns:
        The model with loaded weights, or None if the cache doesn't exist.
    """
    if os.path.exists(cache_path):
        # SECURITY: pickle.load executes arbitrary code from the file —
        # only load cache files produced by this program.
        with open(cache_path, 'rb') as f:
            model.load_state_dict(pickle.load(f))
        print(f"Model loaded from cache: {cache_path}")
        return model
    return None