Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
| import hashlib | |
| import os | |
| import pickle | |
def get_image_coordinates(H, W):
    """Build a normalized (x, y) coordinate grid for an H-by-W image.

    Args:
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        Tensor of shape (H*W, 2); each row is (x, y) with both
        components spanning [-1, 1], ordered row-major (y varies slowest).
    """
    xs = torch.linspace(-1.0, 1.0, steps=W)
    ys = torch.linspace(-1.0, 1.0, steps=H)
    # 'ij' indexing puts the y axis first so rows of the image map to rows
    # of the grid; flattening row-major keeps pixel order consistent with
    # a reshaped (H, W, C) image tensor.
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    return torch.stack((grid_x.flatten(), grid_y.flatten()), dim=1)
def image_to_tensor(image):
    """Convert a PIL Image to a flat, normalized pixel tensor.

    Args:
        image: PIL Image (any mode; converted to RGB if needed).

    Returns:
        Tensor of shape (H*W, 3) with values in [0, 1], pixel order
        matching a row-major scan of the image.
    """
    # Force three channels so the reshape to (-1, 3) below is valid.
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    # ToTensor yields (C, H, W) scaled to [0, 1]; rearrange to (H, W, C)
    # then flatten the spatial dimensions.
    chw = transforms.ToTensor()(rgb)
    return chw.permute(1, 2, 0).reshape(-1, 3)
def tensor_to_image(tensor, H, W):
    """Convert a flat pixel tensor back into a PIL Image.

    Args:
        tensor: Tensor of shape (H*W, 3) with values expected in [0, 1]
            (out-of-range values are clamped).
        H: Image height in pixels.
        W: Image width in pixels.

    Returns:
        PIL Image of size (W, H) in RGB mode.
    """
    # Restore spatial layout and clamp before quantizing so model outputs
    # slightly outside [0, 1] don't wrap around in uint8.
    pixels = tensor.reshape(H, W, 3).clamp(0, 1)
    arr = (pixels.cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(arr)
def downsample_image(image, scale_factor):
    """Shrink an image by an integer scale factor using bicubic resampling.

    Args:
        image: PIL Image.
        scale_factor: Integer downsampling factor (e.g. 2 halves each side).
            Dimensions are floor-divided, so non-divisible sizes are truncated.

    Returns:
        Downsampled PIL Image.
    """
    width, height = image.size
    target_size = (width // scale_factor, height // scale_factor)
    return image.resize(target_size, Image.BICUBIC)
def train_siren(model, coords, pixels, num_steps=2000, learning_rate=1e-4, device='cpu'):
    """Fit a SIREN (or any coordinate network) to an image by MSE regression.

    Args:
        model: Network mapping (N, 2) coordinates to (N, 3) RGB values.
        coords: Coordinate tensor of shape (H*W, 2).
        pixels: Target pixel tensor of shape (H*W, 3).
        num_steps: Number of full-batch Adam steps.
        learning_rate: Adam learning rate.
        device: Device string (e.g. 'cpu' or 'cuda').

    Returns:
        Tuple of (trained model, list of per-step loss values).
    """
    model = model.to(device)
    coords, pixels = coords.to(device), pixels.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    history = []
    for step in range(num_steps):
        # Full-batch gradient step on the reconstruction error.
        optimizer.zero_grad()
        prediction = model(coords)
        loss = torch.nn.functional.mse_loss(prediction, pixels)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
        # Periodic progress report every 200 steps.
        if (step + 1) % 200 == 0:
            print(f"Step {step + 1}/{num_steps}, Loss: {loss.item():.6f}")
    return model, history
def compute_psnr(img1, img2):
    """Compute Peak Signal-to-Noise Ratio between two images.

    Args:
        img1: First image tensor (H*W, 3), values in [0, 1].
        img2: Second image tensor (H*W, 3), values in [0, 1].

    Returns:
        PSNR in dB as a float; inf for identical images.
    """
    mse = torch.nn.functional.mse_loss(img1, img2)
    if mse.item() == 0:
        # Identical images: avoid log(0).
        return float('inf')
    # With a peak value of 1.0, PSNR = 20*log10(1/sqrt(mse)) = -10*log10(mse).
    return (-10.0 * torch.log10(mse)).item()
def compute_mae(img1, img2):
    """Compute Mean Absolute Error between two images.

    Args:
        img1: First image tensor (H*W, 3), values in [0, 1].
        img2: Second image tensor (H*W, 3), values in [0, 1].

    Returns:
        MAE as a float.
    """
    return (img1 - img2).abs().mean().item()
def compute_ssim_simple(img1, img2, window_size=11):
    """Compute a simplified, global SSIM between two images.

    Statistics (means, variances, covariance) are taken over the whole
    tensor rather than local windows, so this is a single global SSIM
    value. NOTE: `window_size` is accepted for API compatibility but is
    not used by this simplified implementation.

    Args:
        img1: First image tensor (H*W, 3), values in [0, 1].
        img2: Second image tensor (H*W, 3), values in [0, 1].
        window_size: Unused (kept for interface compatibility).

    Returns:
        SSIM value as a float (1.0 for identical images).
    """
    # Standard SSIM stabilizing constants for a dynamic range of 1.0.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    mean1, mean2 = img1.mean(), img2.mean()
    var1 = ((img1 - mean1) ** 2).mean()
    var2 = ((img2 - mean2) ** 2).mean()
    cov = ((img1 - mean1) * (img2 - mean2)).mean()
    numerator = (2 * mean1 * mean2 + C1) * (2 * cov + C2)
    denominator = (mean1 ** 2 + mean2 ** 2 + C1) * (var1 + var2 + C2)
    return (numerator / denominator).item()
def get_model_cache_path(image_path, scale_factor, training_steps, hidden_features, hidden_layers):
    """Generate a deterministic cache file path for a trained model.

    Also ensures the cache directory exists as a side effect.

    Args:
        image_path: Path to the source image.
        scale_factor: Upscaling factor.
        training_steps: Number of training steps.
        hidden_features: Network width.
        hidden_layers: Network depth.

    Returns:
        Path under "model_cache/" encoding all hyperparameters.
    """
    cache_dir = "model_cache"
    os.makedirs(cache_dir, exist_ok=True)
    # Use os.path helpers instead of manual "/"/"." splitting: this handles
    # Windows separators and dotted filenames ("img.v2.png" -> "img.v2"),
    # which the naive split mangled (causing cache-key collisions).
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    filename = f"{training_steps}steps_{scale_factor}x_{image_name}_h{hidden_features}_l{hidden_layers}.pkl"
    return os.path.join(cache_dir, filename)
def save_model(model, cache_path):
    """Persist a model's parameters (state_dict) to disk via pickle.

    The pickle format must stay in sync with `load_model`, which reads
    these files back.

    Args:
        model: SIREN model whose state_dict will be saved.
        cache_path: Destination file path.
    """
    state = model.state_dict()
    with open(cache_path, 'wb') as fh:
        pickle.dump(state, fh)
    print(f"Model saved to cache: {cache_path}")
def load_model(model, cache_path):
    """Load cached parameters into a model if the cache file exists.

    Args:
        model: SIREN model whose architecture matches the cached weights.
        cache_path: Path to the cached state_dict pickle.

    Returns:
        The model with weights loaded, or None if no cache file exists.
    """
    # Guard clause: missing cache means the caller must train from scratch.
    if not os.path.exists(cache_path):
        return None
    with open(cache_path, 'rb') as fh:
        # NOTE(review): pickle.load on untrusted files can execute arbitrary
        # code — only load caches this process (or a trusted one) wrote.
        state = pickle.load(fh)
    model.load_state_dict(state)
    print(f"Model loaded from cache: {cache_path}")
    return model