"""Gradio demo: image super-resolution with a SIREN network.

The uploaded image is treated as ground truth, downsampled by a chosen factor,
and a SIREN is fitted to the low-resolution pixels. The fitted network is then
queried on the ground-truth coordinate grid and compared with the original.
"""

import hashlib
import io

import gradio as gr
import matplotlib

# Gradio event handlers are not guaranteed to run on the main thread, and GUI
# backends can warn or crash when figures are created elsewhere. Agg renders
# purely off-screen, which is all this app needs.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from siren import SIREN
from utils import (
    get_image_coordinates,
    image_to_tensor,
    tensor_to_image,
    downsample_image,
    train_siren,
    compute_psnr,
    compute_mae,
    compute_ssim_simple,
    get_model_cache_path,
    save_model,
    load_model
)


def _render_loss_plot(losses, title='Training Loss Curve'):
    """Render the loss curve to a PNG in memory and return it as a PIL Image."""
    buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.plot(losses, linewidth=2, color='#2E86AB')
        ax.set_xlabel('Training Step', fontsize=10)
        ax.set_ylabel('MSE Loss', fontsize=10)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_facecolor('#f8f9fa')
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100,
                    facecolor='white')
    finally:
        # Close this specific figure even if plotting fails, so repeated
        # requests do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _format_metrics_text(psnr, ssim, mae, training_steps, losses,
                         loaded_from_cache):
    """Build the human-readable metrics summary shown in the UI textbox."""
    if loaded_from_cache:
        # No training happened in this call, so there is no real loss to
        # report; say so instead of printing the placeholder value.
        training_summary = [
            f"Model loaded from cache (trained for {training_steps} steps)",
            "Loss curve is a placeholder; no training was run this time",
        ]
    else:
        final_loss = f"{losses[-1]:.6f}" if len(losses) > 0 else "n/a"
        training_summary = [
            f"Training completed in {training_steps} steps",
            f"Final MSE Loss: {final_loss}",
        ]

    lines = [
        "📊 Quality Metrics (vs Ground Truth):",
        "",
        f"• PSNR: {psnr:.2f} dB (higher is better, >30 dB is good)",
        f"• SSIM: {ssim:.4f} (closer to 1.0 is better)",
        f"• MAE: {mae:.4f} (lower is better)",
        "",
        *training_summary,
    ]
    return "\n".join(lines)


def _image_cache_name(image):
    """Return a cache identifier that is unique to the image content.

    If the object carries a ``name`` attribute, its base name without the
    extension is used, as before. Otherwise the name is derived from a hash of
    the pixel data, so that two different uploads with the same dimensions do
    not share a cached model.
    """
    source_name = getattr(image, 'name', None)
    if source_name:
        return source_name.split('/')[-1].split('.')[0]

    digest = hashlib.sha256()
    digest.update(image.mode.encode())
    digest.update(image.tobytes())
    return f"uploaded_{digest.hexdigest()[:16]}"


def super_resolve_image(input_image, scale_factor, training_steps, hidden_features,
                        hidden_layers, use_cache=True, image_name="uploaded"):
    """Perform super-resolution using SIREN.

    The input is downsampled by ``scale_factor``, a SIREN is fitted to the
    low-resolution pixels (or loaded from cache), and the network is then
    evaluated on the coordinate grid of the original resolution.

    Args:
        input_image: PIL Image (high-res ground truth)
        scale_factor: Upscaling factor (2, 4, or 8)
        training_steps: Number of training steps
        hidden_features: Number of hidden units
        hidden_layers: Number of hidden layers
        use_cache: Whether to load/save trained models from/to the cache
        image_name: Name for cache identification; callers should make it
            unique per image, since it is part of the cache key

    Returns:
        Tuple ``(downsampled_image, loss_plot, super_resolved_image, gt_image,
        metrics_text)``: four PIL images followed by a metrics summary string,
        in the order the UI outputs expect.

    Raises:
        ValueError: If ``input_image`` is None.
    """
    if input_image is None:
        raise ValueError("input_image is required (got None)")

    # UI widgets may deliver numeric values as floats; layer sizes and the
    # step count must be integers.
    training_steps = int(training_steps)
    hidden_features = int(hidden_features)
    hidden_layers = int(hidden_layers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Get original (ground truth) dimensions
    gt_image = input_image
    if gt_image.mode != "RGB":
        # The network is built with out_features=3, so metrics against the
        # ground truth only make sense for three-channel images.
        gt_image = gt_image.convert("RGB")
    W_gt, H_gt = gt_image.size

    # Downsample the image
    downsampled_image = downsample_image(gt_image, scale_factor)
    W_low, H_low = downsampled_image.size

    print(f"Ground truth size: {W_gt}x{H_gt}")
    print(f"Downsampled size: {W_low}x{H_low}")
    print(f"Target upscale: {scale_factor}x")

    # Convert downsampled image to tensor
    low_res_pixels = image_to_tensor(downsampled_image)
    low_res_coords = get_image_coordinates(H_low, W_low)

    # Cache key covers the image identity, its size and every hyperparameter
    cache_path = get_model_cache_path(
        f"{image_name}_{W_gt}x{H_gt}",
        scale_factor,
        training_steps,
        hidden_features,
        hidden_layers
    )

    # Create SIREN model
    model = SIREN(
        in_features=2,
        hidden_features=hidden_features,
        hidden_layers=hidden_layers,
        out_features=3,
        outermost_linear=True,
        first_omega_0=30,
        hidden_omega_0=30
    )

    # Try to load from cache. An explicit flag is used rather than inferring
    # the cache hit from the loss list, which breaks when the list is empty.
    loaded_from_cache = False
    if use_cache:
        cached_model = load_model(model, cache_path)
        if cached_model is not None:
            model = cached_model
            loaded_from_cache = True
            print("Using cached model!")

    if loaded_from_cache:
        # No loss history is stored with the model; use a flat placeholder
        # so the plot output still has something to show.
        losses = [0.01] * training_steps
    else:
        print("Training SIREN model...")
        model, losses = train_siren(
            model=model,
            coords=low_res_coords,
            pixels=low_res_pixels,
            num_steps=training_steps,
            learning_rate=1e-4,
            device=device
        )
        print("Training complete!")

        # Save to cache
        if use_cache:
            save_model(model, cache_path)

    # Generate super-resolved image at original resolution.
    # A freshly loaded cached model is not guaranteed to live on `device`,
    # so move it explicitly (a no-op when it is already there).
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        high_res_coords = get_image_coordinates(H_gt, W_gt).to(device)
        super_resolved_pixels = model(high_res_coords)

    # Convert to image
    super_resolved_image = tensor_to_image(super_resolved_pixels, H_gt, W_gt)

    # Compute quality metrics against the ground truth (on CPU)
    gt_pixels = image_to_tensor(gt_image)
    predicted_pixels = super_resolved_pixels.cpu()
    psnr = compute_psnr(predicted_pixels, gt_pixels)
    mae = compute_mae(predicted_pixels, gt_pixels)
    ssim = compute_ssim_simple(predicted_pixels, gt_pixels)

    print(f"\nQuality Metrics:")
    print(f" PSNR: {psnr:.2f} dB")
    print(f" SSIM: {ssim:.4f}")
    print(f" MAE: {mae:.4f}")

    # Create metrics display
    metrics_text = _format_metrics_text(
        psnr, ssim, mae, training_steps, losses, loaded_from_cache
    )

    # Create loss plot
    plot_title = (
        'Training Loss Curve (cached model: placeholder)'
        if loaded_from_cache else 'Training Loss Curve'
    )
    loss_plot = _render_loss_plot(losses, plot_title)

    # Return individual images and metrics
    # Order: downsampled, loss_plot, super_resolved, gt, metrics (matches UI layout)
    return downsampled_image, loss_plot, super_resolved_image, gt_image, metrics_text


def super_resolve_with_name(input_image, scale_factor, training_steps,
                            hidden_features, hidden_layers, use_cache):
    """UI entry point: validate the input and pick a per-image cache name.

    Raises:
        gr.Error: If no image has been provided, so the UI shows a readable
            message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running super-resolution.")

    # Hashing the pixels is only worthwhile when the cache is actually used.
    image_name = _image_cache_name(input_image) if use_cache else "uploaded"
    return super_resolve_image(input_image, scale_factor, training_steps,
                               hidden_features, hidden_layers, use_cache, image_name)


# Create Gradio interface
with gr.Blocks(title="SIREN Super-Resolution") as demo:
    gr.Markdown(
        """
        # 🔥 SIREN Super-Resolution Demo

        Upload a high-resolution image, and watch **SIREN** (Sinusoidal Representation Networks)
        learn to super-resolve it from an artificially downsampled version.

        **How it works:** Your image is downsampled → SIREN learns the low-res → Generates high-res → Compare with original!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input")
            input_image = gr.Image(
                type="pil",
                label="Upload High-Resolution Image",
                height=300
            )

            scale_factor = gr.Radio(
                choices=[2, 4, 8],
                value=2,
                label="Downsampling Scale Factor",
                info="Higher scale = harder task"
            )

            training_steps = gr.Dropdown(
                choices=[500, 1000, 1500, 2000, 3000, 4000, 5000],
                value=2000,
                label="Training Epochs/Steps",
                info="More steps = better quality but slower"
            )

            use_cache = gr.Checkbox(
                value=True,
                label="Use Model Cache",
                info="Save/load trained models to avoid retraining"
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                hidden_features = gr.Slider(
                    minimum=128,
                    maximum=512,
                    value=256,
                    step=64,
                    label="Hidden Features",
                    info="Network width"
                )

                hidden_layers = gr.Slider(
                    minimum=2,
                    maximum=6,
                    value=3,
                    step=1,
                    label="Hidden Layers",
                    info="Network depth"
                )

            run_btn = gr.Button("🚀 Run Super-Resolution", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### 📊 Results & Comparison")

            with gr.Tabs():
                with gr.Tab("📉 Side-by-Side Comparison"):
                    gr.Markdown("**Low-Resolution Input & Training**")
                    with gr.Row():
                        output_downsampled = gr.Image(
                            label="Downsampled (Input)",
                            type="pil",
                            height=300
                        )
                        output_loss_plot = gr.Image(
                            label="Training Loss Curve",
                            type="pil",
                            height=300
                        )

                    gr.Markdown("**High-Resolution Comparison**")
                    with gr.Row():
                        output_super_resolved = gr.Image(
                            label="Super-Resolved (SIREN Prediction)",
                            type="pil",
                            height=300
                        )
                        output_ground_truth = gr.Image(
                            label="Ground Truth (Original)",
                            type="pil",
                            height=300
                        )

                with gr.Tab("📈 Quality Metrics"):
                    metrics_display = gr.Textbox(
                        label="Quality Analysis",
                        lines=10,
                        max_lines=15
                    )

    # Examples
    gr.Markdown("### 📸 Try these examples:")

    gr.Examples(
        examples=[
            ["samples/cat.jpg", 2, 2000, 256, 3, True],
            ["samples/landscape.jpg", 4, 3000, 256, 3, True],
            ["samples/portrait.jpg", 2, 2000, 256, 3, True],
            ["samples/flower.jpg", 4, 3000, 256, 4, True],
        ],
        inputs=[input_image, scale_factor, training_steps, hidden_features,
                hidden_layers, use_cache],
        outputs=[output_downsampled, output_loss_plot, output_super_resolved,
                 output_ground_truth, metrics_display],
        fn=super_resolve_with_name,
        cache_examples=False,
    )

    gr.Markdown(
        """
        ### 📚 About SIREN & Metrics

        **SIREN** uses sine activation functions for representing continuous signals with fine details.

        **Quality Metrics Explained:**
        - **PSNR** (Peak Signal-to-Noise Ratio): Measures reconstruction quality. >30 dB is good, >40 dB is excellent.
        - **SSIM** (Structural Similarity Index): Perceptual quality metric. 1.0 is perfect, >0.9 is very good.
        - **MAE** (Mean Absolute Error): Average pixel difference. Lower is better.

        **Tips for Better Results:**
        - Start with 2x scale for quick testing
        - Use 3000-5000 steps for 4x and 8x scaling
        - Enable model cache to avoid retraining identical settings
        - Higher scale factors need more training steps and network capacity

        **Reference:** [SIREN Paper](https://arxiv.org/abs/2006.09661) | [Tutorial](https://github.com/nipunbatra/pml-teaching/blob/master/notebooks/siren.ipynb)
        """
    )

    # Connect the button
    run_btn.click(
        fn=super_resolve_with_name,
        inputs=[input_image, scale_factor, training_steps, hidden_features,
                hidden_layers, use_cache],
        outputs=[output_downsampled, output_loss_plot, output_super_resolved,
                 output_ground_truth, metrics_display]
    )


if __name__ == "__main__":
    demo.launch()