Nipun's picture
Complete SIREN super-resolution demo with improvements
691ba3c
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import hashlib
import os
import pickle
def get_image_coordinates(H, W):
    """Build a flattened grid of normalized (x, y) pixel coordinates.

    Both axes are sampled evenly over [-1, 1]. Rows are laid out in
    row-major order (y outer, x inner), so entry ``r * W + c`` holds the
    coordinate of pixel (row r, column c).

    Args:
        H: Image height (number of rows).
        W: Image width (number of columns).

    Returns:
        coords: Tensor of shape (H*W, 2); column 0 is x, column 1 is y.
    """
    cols = torch.linspace(-1, 1, W)
    rows = torch.linspace(-1, 1, H)
    # 'ij' indexing keeps the first grid dimension as rows (y).
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    paired = torch.stack((grid_x, grid_y), dim=-1)  # (H, W, 2) as (x, y)
    return paired.reshape(-1, 2)
def image_to_tensor(image):
    """Flatten a PIL image into a per-pixel RGB tensor.

    Non-RGB images (e.g. grayscale or RGBA) are converted to RGB first.
    ``transforms.ToTensor`` produces a (C, H, W) float tensor scaled to
    [0, 1]; it is then moved to channel-last layout and flattened.

    Args:
        image: PIL Image.

    Returns:
        Tensor of shape (H*W, 3) with values in [0, 1], pixels in
        row-major order.
    """
    rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
    chw = transforms.ToTensor()(rgb_image)
    hwc = chw.permute(1, 2, 0)
    return hwc.reshape(-1, 3)
def tensor_to_image(tensor, H, W):
    """Convert a flattened RGB tensor back to a PIL Image.

    The tensor is detached from the autograd graph first, so raw model
    outputs can be passed directly without wrapping them in
    ``torch.no_grad()``. Values are clamped to [0, 1], scaled to [0, 255]
    and *rounded* (not truncated) to uint8. Rounding means a uint8 image
    survives ``image_to_tensor`` -> ``tensor_to_image`` unchanged.

    Args:
        tensor: Tensor of shape (H*W, 3) with values nominally in [0, 1];
            may live on any device and may require grad.
        H: Image height.
        W: Image width.

    Returns:
        PIL Image of size (W, H) built from an (H, W, 3) uint8 array.
    """
    # detach(): .numpy() refuses tensors that require grad.
    img = tensor.detach().reshape(H, W, 3).clamp(0, 1)
    # round() before the cast: truncation would map e.g. 127.9999 -> 127.
    pixels = (img * 255).round().to(torch.uint8).cpu().numpy()
    return Image.fromarray(pixels)
def downsample_image(image, scale_factor):
    """Downsample an image by ``scale_factor`` using bicubic resampling.

    Each side is floor-divided by ``scale_factor``. Float factors are
    accepted, and the resulting sizes are cast to int because
    ``Image.resize`` requires integers. Each side is kept at least 1 px,
    so very small inputs do not collapse to an empty image.

    Args:
        image: PIL Image.
        scale_factor: Positive downsampling factor (e.g. 2 for half size).

    Returns:
        Downsampled PIL Image.

    Raises:
        ValueError: If ``scale_factor`` is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")
    W, H = image.size
    new_W = max(1, int(W // scale_factor))
    new_H = max(1, int(H // scale_factor))
    return image.resize((new_W, new_H), Image.BICUBIC)
def train_siren(model, coords, pixels, num_steps=2000, learning_rate=1e-4,
                device='cpu', log_every=200):
    """Fit a coordinate network (e.g. SIREN) to an image with MSE + Adam.

    The full coordinate set is used as a single batch on every step.

    Args:
        model: SIREN model (any ``torch.nn.Module`` mapping coords -> RGB).
        coords: Coordinate tensor (H*W, 2).
        pixels: Pixel values tensor (H*W, 3).
        num_steps: Number of optimization steps.
        learning_rate: Adam learning rate.
        device: Device to train on; model and data are moved there.
        log_every: Print the loss every ``log_every`` steps. Pass 0 or
            None to disable progress output.

    Returns:
        Tuple ``(model, losses)``: the trained model (moved to ``device``)
        and a list with the float loss of every step.
    """
    model = model.to(device)
    coords = coords.to(device)
    pixels = pixels.to(device)
    # Build the optimizer after .to() so it references the final parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    losses = []
    for step in range(1, num_steps + 1):
        pred_pixels = model(coords)
        loss = torch.nn.functional.mse_loss(pred_pixels, pixels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() syncs with the device; call it once and reuse the value.
        loss_value = loss.item()
        losses.append(loss_value)
        if log_every and step % log_every == 0:
            print(f"Step {step}/{num_steps}, Loss: {loss_value:.6f}")
    return model, losses
def compute_psnr(img1, img2, max_val=1.0):
    """Compute Peak Signal-to-Noise Ratio between two images.

    PSNR = 20 * log10(max_val / sqrt(MSE)).

    Args:
        img1: First image tensor, e.g. (H*W, 3) in [0, max_val].
        img2: Second image tensor with the same shape as ``img1``.
        max_val: Maximum possible pixel value (data range); 1.0 for
            images normalized to [0, 1], 255 for uint8-scaled values.

    Returns:
        PSNR value in dB as a Python float; ``float('inf')`` when the
        images are identical.

    Raises:
        ValueError: If the two tensors have different shapes (mse_loss
            would otherwise broadcast them into a meaningless value).
    """
    if img1.shape != img2.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(img1.shape)} vs {tuple(img2.shape)}"
        )
    mse = torch.nn.functional.mse_loss(img1, img2)
    if mse == 0:
        return float('inf')
    psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
    return psnr.item()
def compute_mae(img1, img2):
    """Return the mean absolute (L1) difference between two images.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1].
        img2: Second image tensor (H*W, 3) in [0, 1].

    Returns:
        MAE averaged over every element, as a Python float.
    """
    return torch.nn.functional.l1_loss(img1, img2).item()
def compute_ssim_simple(img1, img2, window_size=11, H=None, W=None):
    """Compute a simplified SSIM between two images.

    Two modes:

    * Global (default, ``H`` or ``W`` not given): means, variances and
      covariance are taken over *all* pixels and channels at once, and a
      single SSIM value is formed from them. ``window_size`` is ignored
      in this mode.
    * Windowed (``H`` and ``W`` given): the flattened inputs are reshaped
      to (H, W, C). Local statistics are computed per channel with a
      uniform ``window_size`` x ``window_size`` window (stride 1, no
      padding), and the resulting SSIM map is averaged. The window is
      shrunk to fit images smaller than ``window_size``.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1].
        img2: Second image tensor (H*W, 3) in [0, 1].
        window_size: Side length of the local window (windowed mode only).
        H: Image height; enables windowed mode together with ``W``.
        W: Image width; enables windowed mode together with ``H``.

    Returns:
        SSIM value as a Python float. It is 1.0 for identical images and
        can be negative for anti-correlated ones (range roughly [-1, 1]).
    """
    # Stabilizing constants for a data range of 1.0.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    if H is None or W is None:
        mu1 = img1.mean()
        mu2 = img2.mean()
        sigma1_sq = ((img1 - mu1) ** 2).mean()
        sigma2_sq = ((img2 - mu2) ** 2).mean()
        sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \
            ((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
        return ssim.item()

    # Windowed mode: (H*W, C) -> (1, C, H, W) so avg_pool2d works per channel.
    x = img1.reshape(H, W, -1).permute(2, 0, 1).unsqueeze(0)
    y = img2.reshape(H, W, -1).permute(2, 0, 1).unsqueeze(0)
    k = max(1, min(window_size, H, W))

    def _local_mean(t):
        # Uniform-window local average.
        return torch.nn.functional.avg_pool2d(t, k, stride=1)

    mu_x = _local_mean(x)
    mu_y = _local_mean(y)
    # E[X^2] - E[X]^2 form of the local (co)variances.
    sigma_x_sq = _local_mean(x * x) - mu_x ** 2
    sigma_y_sq = _local_mean(y * y) - mu_y ** 2
    sigma_xy = _local_mean(x * y) - mu_x * mu_y
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / \
        ((mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x_sq + sigma_y_sq + c2))
    return ssim_map.mean().item()
def get_model_cache_path(image_path, scale_factor, training_steps, hidden_features,
                         hidden_layers, cache_dir="model_cache"):
    """Generate the cache file path for a trained model.

    The filename encodes the training configuration and the image's base
    name without its final extension, e.g.
    ``model_cache/2000steps_4x_cat.v2_h256_l3.pkl``. ``os.path`` is used
    to extract the name, so platform path separators are handled and
    multi-dot names keep everything but the last extension. With a plain
    ``split('.')``, ``cat.v1.png`` and ``cat.v2.png`` would share a cache
    entry.

    Args:
        image_path: Path to image.
        scale_factor: Upscaling factor.
        training_steps: Number of training steps.
        hidden_features: Network width.
        hidden_layers: Network depth.
        cache_dir: Directory that holds cached models; created if missing.

    Returns:
        Cache file path (inside ``cache_dir``).
    """
    os.makedirs(cache_dir, exist_ok=True)
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    filename = (f"{training_steps}steps_{scale_factor}x_{image_name}"
                f"_h{hidden_features}_l{hidden_layers}.pkl")
    return os.path.join(cache_dir, filename)
def save_model(model, cache_path):
    """Save a model's ``state_dict`` to the cache with pickle.

    The data is first written to a temporary file in the same directory,
    then moved into place with ``os.replace``. A crash or interrupt during
    writing therefore never leaves a truncated file at ``cache_path``.
    Such a file would otherwise be picked up as a valid cache entry later.

    Args:
        model: SIREN model (any ``torch.nn.Module``).
        cache_path: Path to save model; its directory must already exist.
    """
    import tempfile

    directory = os.path.dirname(cache_path) or "."
    # Same directory as the target so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_", suffix=".pkl")
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(model.state_dict(), f)
        os.replace(tmp_path, cache_path)
    finally:
        # Only still present if dumping or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Model saved to cache: {cache_path}")
def load_model(model, cache_path):
    """Load a cached ``state_dict`` into ``model``.

    A missing cache file and an unreadable one (empty or truncated
    pickle) are both treated as a cache miss, so the caller can simply
    retrain. An architecture mismatch is not hidden: ``load_state_dict``
    still raises in that case.

    Security note: ``pickle.load`` can execute arbitrary code, so only
    point this at cache files you produced yourself.

    Args:
        model: SIREN model (architecture must match).
        cache_path: Path to cached model.

    Returns:
        The same ``model`` with weights loaded, or None if the cache file
        doesn't exist or cannot be unpickled.
    """
    try:
        with open(cache_path, 'rb') as f:
            state_dict = pickle.load(f)
    except FileNotFoundError:
        return None
    except (EOFError, pickle.UnpicklingError) as err:
        print(f"Ignoring unreadable model cache {cache_path}: {err}")
        return None
    model.load_state_dict(state_dict)
    print(f"Model loaded from cache: {cache_path}")
    return model