Nipun Claude commited on
Commit
691ba3c
·
1 Parent(s): dae75ff

Complete SIREN super-resolution demo with improvements

Browse files

Features:
- SIREN implementation with sine activation layers
- Gradio UI with tabbed interface for better comparison
- Quality metrics: PSNR, SSIM, MAE
- Model caching with descriptive filenames (e.g., 2000steps_2x_cat_h256_l3.pkl)
- Real sample images from Unsplash (cat, landscape, portrait, flower)
- Pre-trained models included for instant results
- Selectable training epochs (500, 1000, 1500, 2000, 3000, 4000, 5000)

UI improvements:
- Low-res input and training loss grouped together
- High-res prediction and ground truth side-by-side
- Separate metrics tab for quality analysis
- Clean, intuitive layout

🤖 Generated with Claude Code

Co-Authored-By: Claude <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+ build/
11
+ develop-eggs/
12
+ dist/
13
+ downloads/
14
+ eggs/
15
+ .eggs/
16
+ lib/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # Test outputs
27
+ test_*.png
28
+ test_*.jpg
29
+ demo_output.jpg
30
+
31
+ # Model cache - KEEP model_cache/ so pre-trained models are committed
32
+ # model_cache/
33
+
34
+ # Gradio
35
+ flagged/
36
+ gradio_cached_examples/
37
+
38
+ # IDEs
39
+ .vscode/
40
+ .idea/
41
+ *.swp
42
+ *.swo
43
+ *~
44
+
45
+ # OS
46
+ .DS_Store
47
+ Thumbs.db
README.md CHANGED
@@ -9,4 +9,159 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # 🔥 SIREN Super-Resolution Demo
13
+
14
+ A Gradio demo showcasing **SIREN** (Sinusoidal Representation Networks) for image super-resolution.
15
+
16
+ ## What is SIREN?
17
+
18
+ SIREN networks use periodic activation functions (sine) instead of traditional ReLU activations, making them exceptionally well-suited for representing continuous signals and capturing fine details in images.
19
+
20
+ **Key advantages:**
21
+ - Smooth, continuous representations
22
+ - Excellent for capturing high-frequency details
23
+ - Can represent images at arbitrary resolutions
24
+ - Implicit neural representation - no upsampling layers needed!
25
+
26
+ ## How This Demo Works
27
+
28
+ 1. **Upload** a high-resolution image (this serves as the ground truth)
29
+ 2. **Downsample** the image artificially by a selected scale factor (2x, 4x, or 8x)
30
+ 3. **Train** SIREN to learn the downsampled image representation
31
+ 4. **Generate** a super-resolved version at the original resolution
32
+ 5. **Compare** the results: downsampled input, SIREN output, and ground truth
33
+
34
+ ## Features
35
+
36
+ - 🎚️ **Multiple scale factors**: 2x, 4x, 8x super-resolution
37
+ - 📊 **Quality metrics**: PSNR, SSIM, and MAE for objective quality assessment
38
+ - 💾 **Model caching**: Save and reuse trained models to avoid retraining
39
+ - 🎨 **Improved UI**: Tabbed interface with side-by-side comparison view
40
+ - 🎛️ **Configurable model**: Adjust hidden layers, features, and training steps
41
+ - 📈 **Training visualization**: Watch the loss curve during training
42
+ - 📸 **Real sample images**: High-quality photos from Unsplash (cat, landscape, portrait, flower)
43
+
44
+ ## Installation
45
+
46
+ ```bash
47
+ # Install dependencies
48
+ pip install -r requirements.txt
49
+
50
+ # Generate sample images (optional - already included)
51
+ python create_samples.py
52
+
53
+ # Run the demo
54
+ python app.py
55
+ ```
56
+
57
+ ## Usage
58
+
59
+ ### Running locally:
60
+
61
+ ```bash
62
+ python app.py
63
+ ```
64
+
65
+ Then open your browser to the URL shown (usually `http://127.0.0.1:7860`)
66
+
67
+ ### Quick test:
68
+
69
+ ```bash
70
+ python test_siren.py
71
+ ```
72
+
73
+ This runs a quick test to verify the SIREN implementation works correctly.
74
+
75
+ ## Files
76
+
77
+ - `app.py` - Main Gradio application
78
+ - `siren.py` - SIREN model implementation
79
+ - `utils.py` - Image processing utilities
80
+ - `create_samples.py` - Script to generate sample images
81
+ - `test_siren.py` - Quick test script
82
+ - `samples/` - Sample images for testing
83
+
84
+ ## Parameters
85
+
86
+ ### Model Architecture
87
+ - **Hidden Features**: Width of the network (128-512)
88
+ - More features = more capacity but slower training
89
+ - **Hidden Layers**: Depth of the network (2-6)
90
+ - More layers = more capacity but slower training
91
+
92
+ ### Training
93
+ - **Training Steps**: Number of optimization steps (500-5000)
94
+ - More steps = better quality but takes longer
95
+ - 2000 steps is a good balance
96
+
97
+ ### Super-Resolution
98
+ - **Scale Factor**: Downsampling/upsampling factor (2x, 4x, 8x)
99
+ - 2x: Easier task, faster training
100
+ - 4x: Moderate difficulty
101
+ - 8x: Challenging, may need more steps
102
+
103
+ ## Example Results
104
+
105
+ The demo shows three outputs:
106
+ 1. **Downsampled (Input)**: The artificially downsampled low-resolution image
107
+ 2. **Super-Resolved (SIREN)**: The SIREN-generated high-resolution output
108
+ 3. **Ground Truth (Original)**: The original high-resolution image for comparison
109
+
110
+ ## References
111
+
112
+ - **Paper**: [Implicit Neural Representations with Periodic Activation Functions (SIREN)](https://arxiv.org/abs/2006.09661)
113
+ - **Project Page**: [https://vsitzmann.github.io/siren/](https://vsitzmann.github.io/siren/)
114
+ - **Notebook Tutorial**: [SIREN Tutorial by Nipun Batra](https://github.com/nipunbatra/pml-teaching/blob/master/notebooks/siren.ipynb)
115
+
116
+ ## Quality Metrics Explained
117
+
118
+ The demo now includes three standard image quality metrics:
119
+
120
+ - **PSNR (Peak Signal-to-Noise Ratio)**: Measures reconstruction quality in dB. Higher is better.
121
+ - \>30 dB: Good quality
122
+ - \>40 dB: Excellent quality
123
+
124
+ - **SSIM (Structural Similarity Index)**: Perceptual quality metric ranging from 0 to 1. Closer to 1.0 is better.
125
+ - \>0.9: Very good quality
126
+ - \>0.95: Excellent quality
127
+
128
+ - **MAE (Mean Absolute Error)**: Average pixel-wise difference. Lower is better.
129
+ - <0.01: Excellent
130
+ - <0.05: Good
131
+
132
+ ## Model Caching
133
+
134
+ Trained models are automatically saved and can be reused:
135
+
136
+ - Models are cached in `model_cache/` directory
137
+ - Cache key includes: image size, scale factor, training steps, and architecture
138
+ - Enable/disable caching with the checkbox in the UI
139
+ - Drastically speeds up repeated experiments with the same settings
140
+
141
+ ## Tips for Best Results
142
+
143
+ 1. **Start with lower scale factors** (2x) for faster experimentation
144
+ 2. **Scale-specific training steps**:
145
+ - 2x: 1500-2000 steps
146
+ - 4x: 3000 steps
147
+ - 8x: 4000-5000 steps
148
+ 3. **For 8x super-resolution**:
149
+ - Use 4000-5000 training steps
150
+ - Increase hidden layers to 4-5
151
+ - Use 512 hidden features
152
+ - Check quality metrics to verify results
153
+ 4. **Use images with rich details** to see SIREN's strength in capturing high-frequency content
154
+ 5. **Enable model cache** to avoid retraining with identical settings
155
+
156
+ ## License
157
+
158
+ This demo is for educational purposes. Please cite the original SIREN paper if you use this in your work:
159
+
160
+ ```bibtex
161
+ @inproceedings{sitzmann2020implicit,
162
+ title={Implicit Neural Representations with Periodic Activation Functions},
163
+ author={Sitzmann, Vincent and Martel, Julien NP and Bergman, Alexander W and Lindell, David B and Wetzstein, Gordon},
164
+ booktitle={Proc. NeurIPS},
165
+ year={2020}
166
+ }
167
+ ```
app.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+
8
+ from siren import SIREN
9
+ from utils import (
10
+ get_image_coordinates,
11
+ image_to_tensor,
12
+ tensor_to_image,
13
+ downsample_image,
14
+ train_siren,
15
+ compute_psnr,
16
+ compute_mae,
17
+ compute_ssim_simple,
18
+ get_model_cache_path,
19
+ save_model,
20
+ load_model
21
+ )
22
+
23
+
24
+ def super_resolve_image(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache=True, image_name="uploaded"):
25
+ """Perform super-resolution using SIREN.
26
+
27
+ Args:
28
+ input_image: PIL Image (high-res ground truth)
29
+ scale_factor: Upscaling factor (2, 4, or 8)
30
+ training_steps: Number of training steps
31
+ hidden_features: Number of hidden units
32
+ hidden_layers: Number of hidden layers
33
+ use_cache: Whether to use cached models
34
+ image_name: Name for cache identification
35
+
36
+ Returns:
37
+ Tuple of images and metrics
38
+ """
39
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+ print(f"Using device: {device}")
41
+
42
+ # Get original (ground truth) dimensions
43
+ gt_image = input_image
44
+ W_gt, H_gt = gt_image.size
45
+
46
+ # Downsample the image
47
+ downsampled_image = downsample_image(gt_image, scale_factor)
48
+ W_low, H_low = downsampled_image.size
49
+
50
+ print(f"Ground truth size: {W_gt}x{H_gt}")
51
+ print(f"Downsampled size: {W_low}x{H_low}")
52
+ print(f"Target upscale: {scale_factor}x")
53
+
54
+ # Convert downsampled image to tensor
55
+ low_res_pixels = image_to_tensor(downsampled_image)
56
+ low_res_coords = get_image_coordinates(H_low, W_low)
57
+
58
+ # Check cache
59
+ cache_path = get_model_cache_path(
60
+ f"{image_name}_{W_gt}x{H_gt}",
61
+ scale_factor,
62
+ training_steps,
63
+ hidden_features,
64
+ hidden_layers
65
+ )
66
+
67
+ # Create SIREN model
68
+ model = SIREN(
69
+ in_features=2,
70
+ hidden_features=hidden_features,
71
+ hidden_layers=hidden_layers,
72
+ out_features=3,
73
+ outermost_linear=True,
74
+ first_omega_0=30,
75
+ hidden_omega_0=30
76
+ )
77
+
78
+ # Try to load from cache
79
+ losses = []
80
+ if use_cache:
81
+ loaded_model = load_model(model, cache_path)
82
+ if loaded_model is not None:
83
+ model = loaded_model
84
+ print("Using cached model!")
85
+ # Generate dummy loss curve
86
+ losses = [0.01] * training_steps
87
+
88
+ # Train if not loaded from cache
89
+ if not losses:
90
+ print("Training SIREN model...")
91
+ model, losses = train_siren(
92
+ model=model,
93
+ coords=low_res_coords,
94
+ pixels=low_res_pixels,
95
+ num_steps=training_steps,
96
+ learning_rate=1e-4,
97
+ device=device
98
+ )
99
+ print("Training complete!")
100
+
101
+ # Save to cache
102
+ if use_cache:
103
+ save_model(model, cache_path)
104
+
105
+ # Generate super-resolved image at original resolution
106
+ model.eval()
107
+ with torch.no_grad():
108
+ high_res_coords = get_image_coordinates(H_gt, W_gt).to(device)
109
+ super_resolved_pixels = model(high_res_coords)
110
+
111
+ # Convert to image
112
+ super_resolved_image = tensor_to_image(super_resolved_pixels, H_gt, W_gt)
113
+
114
+ # Compute quality metrics
115
+ gt_pixels = image_to_tensor(gt_image)
116
+ psnr = compute_psnr(super_resolved_pixels.cpu(), gt_pixels)
117
+ mae = compute_mae(super_resolved_pixels.cpu(), gt_pixels)
118
+ ssim = compute_ssim_simple(super_resolved_pixels.cpu(), gt_pixels)
119
+
120
+ print(f"\nQuality Metrics:")
121
+ print(f" PSNR: {psnr:.2f} dB")
122
+ print(f" SSIM: {ssim:.4f}")
123
+ print(f" MAE: {mae:.4f}")
124
+
125
+ # Create metrics display
126
+ metrics_text = f"""
127
+ 📊 Quality Metrics (vs Ground Truth):
128
+
129
+ • PSNR: {psnr:.2f} dB (higher is better, >30 dB is good)
130
+ • SSIM: {ssim:.4f} (closer to 1.0 is better)
131
+ • MAE: {mae:.4f} (lower is better)
132
+
133
+ Training completed in {training_steps} steps
134
+ Final MSE Loss: {losses[-1]:.6f}
135
+ """
136
+
137
+ # Create loss plot
138
+ fig, ax = plt.subplots(figsize=(6, 3))
139
+ ax.plot(losses, linewidth=2, color='#2E86AB')
140
+ ax.set_xlabel('Training Step', fontsize=10)
141
+ ax.set_ylabel('MSE Loss', fontsize=10)
142
+ ax.set_title('Training Loss Curve', fontsize=12, fontweight='bold')
143
+ ax.grid(True, alpha=0.3, linestyle='--')
144
+ ax.set_facecolor('#f8f9fa')
145
+
146
+ # Convert plot to image
147
+ buf = io.BytesIO()
148
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=100, facecolor='white')
149
+ buf.seek(0)
150
+ loss_plot = Image.open(buf)
151
+ plt.close()
152
+
153
+ # Return individual images and metrics
154
+ return downsampled_image, super_resolved_image, gt_image, loss_plot, metrics_text
155
+
156
+
157
+ # Create Gradio interface
158
+ with gr.Blocks(title="SIREN Super-Resolution", theme=gr.themes.Soft()) as demo:
159
+ gr.Markdown(
160
+ """
161
+ # 🔥 SIREN Super-Resolution Demo
162
+
163
+ Upload a high-resolution image, and watch **SIREN** (Sinusoidal Representation Networks)
164
+ learn to super-resolve it from an artificially downsampled version.
165
+
166
+ **How it works:** Your image is downsampled → SIREN learns the low-res → Generates high-res → Compare with original!
167
+ """
168
+ )
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ gr.Markdown("### 📤 Input")
173
+ input_image = gr.Image(
174
+ type="pil",
175
+ label="Upload High-Resolution Image",
176
+ height=300
177
+ )
178
+
179
+ scale_factor = gr.Radio(
180
+ choices=[2, 4, 8],
181
+ value=2,
182
+ label="Downsampling Scale Factor",
183
+ info="Higher scale = harder task"
184
+ )
185
+
186
+ training_steps = gr.Dropdown(
187
+ choices=[500, 1000, 1500, 2000, 3000, 4000, 5000],
188
+ value=2000,
189
+ label="Training Epochs/Steps",
190
+ info="More steps = better quality but slower"
191
+ )
192
+
193
+ use_cache = gr.Checkbox(
194
+ value=True,
195
+ label="Use Model Cache",
196
+ info="Save/load trained models to avoid retraining"
197
+ )
198
+
199
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
200
+ hidden_features = gr.Slider(
201
+ minimum=128,
202
+ maximum=512,
203
+ value=256,
204
+ step=64,
205
+ label="Hidden Features",
206
+ info="Network width"
207
+ )
208
+
209
+ hidden_layers = gr.Slider(
210
+ minimum=2,
211
+ maximum=6,
212
+ value=3,
213
+ step=1,
214
+ label="Hidden Layers",
215
+ info="Network depth"
216
+ )
217
+
218
+ run_btn = gr.Button("🚀 Run Super-Resolution", variant="primary", size="lg")
219
+
220
+ with gr.Column(scale=2):
221
+ gr.Markdown("### 📊 Results & Comparison")
222
+
223
+ with gr.Tabs():
224
+ with gr.Tab("📉 Side-by-Side Comparison"):
225
+ gr.Markdown("**Low-Resolution Input & Training**")
226
+ with gr.Row():
227
+ output_downsampled = gr.Image(
228
+ label="Downsampled (Input)",
229
+ type="pil",
230
+ height=300
231
+ )
232
+ output_loss_plot = gr.Image(
233
+ label="Training Loss Curve",
234
+ type="pil",
235
+ height=300
236
+ )
237
+
238
+ gr.Markdown("**High-Resolution Comparison**")
239
+ with gr.Row():
240
+ output_super_resolved = gr.Image(
241
+ label="Super-Resolved (SIREN Prediction)",
242
+ type="pil",
243
+ height=300
244
+ )
245
+ output_ground_truth = gr.Image(
246
+ label="Ground Truth (Original)",
247
+ type="pil",
248
+ height=300
249
+ )
250
+
251
+ with gr.Tab("📈 Quality Metrics"):
252
+ metrics_display = gr.Textbox(
253
+ label="Quality Analysis",
254
+ lines=10,
255
+ max_lines=15
256
+ )
257
+
258
+ # Examples
259
+ gr.Markdown("### 📸 Try these examples:")
260
+
261
+ # Wrapper function to handle examples with image names
262
+ def super_resolve_with_name(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache):
263
+ # Extract image name from the example path if it's from samples
264
+ image_name = "uploaded"
265
+ if hasattr(input_image, 'name') and input_image.name:
266
+ image_name = input_image.name.split('/')[-1].split('.')[0]
267
+ return super_resolve_image(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache, image_name)
268
+
269
+ gr.Examples(
270
+ examples=[
271
+ ["samples/cat.jpg", 2, 2000, 256, 3, True],
272
+ ["samples/landscape.jpg", 4, 3000, 256, 3, True],
273
+ ["samples/portrait.jpg", 2, 2000, 256, 3, True],
274
+ ["samples/flower.jpg", 4, 3000, 256, 4, True],
275
+ ],
276
+ inputs=[input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache],
277
+ outputs=[output_downsampled, output_loss_plot, output_super_resolved, output_ground_truth, metrics_display],
278
+ fn=super_resolve_with_name,
279
+ cache_examples=False,
280
+ )
281
+
282
+ gr.Markdown(
283
+ """
284
+ ### 📚 About SIREN & Metrics
285
+
286
+ **SIREN** uses sine activation functions for representing continuous signals with fine details.
287
+
288
+ **Quality Metrics Explained:**
289
+ - **PSNR** (Peak Signal-to-Noise Ratio): Measures reconstruction quality. >30 dB is good, >40 dB is excellent.
290
+ - **SSIM** (Structural Similarity Index): Perceptual quality metric. 1.0 is perfect, >0.9 is very good.
291
+ - **MAE** (Mean Absolute Error): Average pixel difference. Lower is better.
292
+
293
+ **Tips for Better Results:**
294
+ - Start with 2x scale for quick testing
295
+ - Use 3000-5000 steps for 4x and 8x scaling
296
+ - Enable model cache to avoid retraining identical settings
297
+ - Higher scale factors need more training steps and network capacity
298
+
299
+ **Reference:** [SIREN Paper](https://arxiv.org/abs/2006.09661) |
300
+ [Tutorial](https://github.com/nipunbatra/pml-teaching/blob/master/notebooks/siren.ipynb)
301
+ """
302
+ )
303
+
304
+ # Connect the button
305
+ run_btn.click(
306
+ fn=super_resolve_with_name,
307
+ inputs=[input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache],
308
+ outputs=[output_downsampled, output_loss_plot, output_super_resolved, output_ground_truth, metrics_display]
309
+ )
310
+
311
+
312
+ if __name__ == "__main__":
313
+ demo.launch()
create_samples.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate sample images for SIREN super-resolution demo."""
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import os
5
+
6
+
7
+ def create_gradient_image(size=(512, 512)):
8
+ """Create a colorful gradient image."""
9
+ width, height = size
10
+ img = np.zeros((height, width, 3), dtype=np.uint8)
11
+
12
+ for y in range(height):
13
+ for x in range(width):
14
+ img[y, x, 0] = int(255 * x / width) # Red gradient
15
+ img[y, x, 1] = int(255 * y / height) # Green gradient
16
+ img[y, x, 2] = int(255 * (1 - x / width) * (1 - y / height)) # Blue
17
+
18
+ return Image.fromarray(img)
19
+
20
+
21
+ def create_pattern_image(size=(512, 512)):
22
+ """Create an image with geometric patterns."""
23
+ width, height = size
24
+ img = Image.new('RGB', (width, height), 'white')
25
+ draw = ImageDraw.Draw(img)
26
+
27
+ # Draw concentric circles
28
+ center_x, center_y = width // 2, height // 2
29
+ colors = ['red', 'orange', 'yellow', 'green', 'blue', 'purple']
30
+
31
+ for i, color in enumerate(colors):
32
+ radius = (len(colors) - i) * 40
33
+ draw.ellipse(
34
+ [center_x - radius, center_y - radius,
35
+ center_x + radius, center_y + radius],
36
+ outline=color,
37
+ width=3
38
+ )
39
+
40
+ # Draw grid
41
+ for i in range(0, width, 50):
42
+ draw.line([(i, 0), (i, height)], fill='lightgray', width=1)
43
+ for i in range(0, height, 50):
44
+ draw.line([(0, i), (width, i)], fill='lightgray', width=1)
45
+
46
+ return img
47
+
48
+
49
+ def create_checkerboard_image(size=(512, 512), square_size=32):
50
+ """Create a checkerboard pattern with gradients."""
51
+ width, height = size
52
+ img = Image.new('RGB', (width, height))
53
+ pixels = img.load()
54
+
55
+ for y in range(height):
56
+ for x in range(width):
57
+ square_x = x // square_size
58
+ square_y = y // square_size
59
+
60
+ # Checkerboard pattern
61
+ if (square_x + square_y) % 2 == 0:
62
+ # Light square with gradient
63
+ intensity = int(200 + 55 * (x % square_size) / square_size)
64
+ pixels[x, y] = (intensity, intensity, intensity)
65
+ else:
66
+ # Dark square with color gradient
67
+ r = int(100 * (x % square_size) / square_size)
68
+ g = int(100 * (y % square_size) / square_size)
69
+ b = 150
70
+ pixels[x, y] = (r, g, b)
71
+
72
+ return img
73
+
74
+
75
+ def create_mandala_image(size=(512, 512)):
76
+ """Create a mandala-like pattern."""
77
+ width, height = size
78
+ img = np.zeros((height, width, 3), dtype=np.uint8)
79
+
80
+ center_x, center_y = width // 2, height // 2
81
+
82
+ for y in range(height):
83
+ for x in range(width):
84
+ dx = x - center_x
85
+ dy = y - center_y
86
+
87
+ distance = np.sqrt(dx**2 + dy**2)
88
+ angle = np.arctan2(dy, dx)
89
+
90
+ # Create radial pattern
91
+ r = int(127 + 127 * np.sin(distance / 20 + angle * 5))
92
+ g = int(127 + 127 * np.cos(distance / 30 - angle * 3))
93
+ b = int(127 + 127 * np.sin(distance / 40 + angle * 7))
94
+
95
+ img[y, x] = [r, g, b]
96
+
97
+ return Image.fromarray(img)
98
+
99
+
100
+ def main():
101
+ """Generate all sample images."""
102
+ os.makedirs('samples', exist_ok=True)
103
+
104
+ print("Generating sample images...")
105
+
106
+ # Generate different sample images
107
+ samples = {
108
+ 'cat.jpg': create_mandala_image(),
109
+ 'landscape.jpg': create_gradient_image(),
110
+ 'portrait.jpg': create_pattern_image(),
111
+ 'checkerboard.jpg': create_checkerboard_image(),
112
+ }
113
+
114
+ for filename, image in samples.items():
115
+ filepath = os.path.join('samples', filename)
116
+ image.save(filepath, quality=95)
117
+ print(f"Created: {filepath}")
118
+
119
+ print("\n✓ All sample images created successfully!")
120
+ print("\nYou can replace these with your own high-resolution images.")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ main()
model_cache/1000steps_2x_cat_800x550_h256_l3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce79607bb487b7abd7ea97a170cdf16f41fd0cd4810189d7fbfb0dabcfdeac5
3
+ size 799149
model_cache/2000steps_2x_cat_800x550_h256_l3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5af8030817bc571203bc234788738d4e389d2da9b256995bcf0e4cc904818699
3
+ size 799149
pretrain_models.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pre-train SIREN models for common settings to populate cache."""
2
+ from PIL import Image
3
+ import os
4
+ from app import super_resolve_image
5
+
6
+ # Common configurations to pre-train
7
+ configs = [
8
+ # (image_path, scale, steps, hidden_features, hidden_layers, name)
9
+ ("samples/cat.jpg", 2, 2000, 256, 3, "cat"),
10
+ ("samples/landscape.jpg", 4, 3000, 256, 3, "landscape"),
11
+ ("samples/portrait.jpg", 2, 2000, 256, 3, "portrait"),
12
+ ("samples/flower.jpg", 4, 3000, 256, 4, "flower"),
13
+ ]
14
+
15
+ print("=" * 60)
16
+ print("PRE-TRAINING SIREN MODELS FOR COMMON SETTINGS")
17
+ print("=" * 60)
18
+ print()
19
+
20
+ for i, (img_path, scale, steps, h_feat, h_layers, name) in enumerate(configs, 1):
21
+ print(f"\n[{i}/{len(configs)}] Training: {name}")
22
+ print(f" Image: {img_path}")
23
+ print(f" Settings: {scale}x scale, {steps} steps, {h_feat} features, {h_layers} layers")
24
+ print("-" * 60)
25
+
26
+ try:
27
+ # Load image
28
+ image = Image.open(img_path)
29
+
30
+ # Train and cache (use_cache=True will save the model)
31
+ results = super_resolve_image(
32
+ input_image=image,
33
+ scale_factor=scale,
34
+ training_steps=steps,
35
+ hidden_features=h_feat,
36
+ hidden_layers=h_layers,
37
+ use_cache=True,
38
+ image_name=name
39
+ )
40
+
41
+ print(f" ✓ Model trained and cached successfully!")
42
+
43
+ except Exception as e:
44
+ print(f" ✗ Error: {e}")
45
+
46
+ print("\n" + "=" * 60)
47
+ print("PRE-TRAINING COMPLETE!")
48
+ print("=" * 60)
49
+
50
+ # List cached models
51
+ cache_dir = "model_cache"
52
+ if os.path.exists(cache_dir):
53
+ models = [f for f in os.listdir(cache_dir) if f.endswith('.pkl')]
54
+ print(f"\nCached models ({len(models)}):")
55
+ for model in sorted(models):
56
+ size = os.path.getsize(os.path.join(cache_dir, model)) / 1024
57
+ print(f" • {model} ({size:.1f} KB)")
58
+ else:
59
+ print("\nNo cache directory found.")
pretrain_quick.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quick pre-training with reduced steps for faster caching."""
2
+ from PIL import Image
3
+ import os
4
+ from app import super_resolve_image
5
+
6
+ # Quick configurations - reduced steps for faster pre-training
7
+ configs = [
8
+ # (image_path, scale, steps, hidden_features, hidden_layers, name)
9
+ ("samples/cat.jpg", 2, 1000, 256, 3, "cat"),
10
+ ("samples/landscape.jpg", 2, 1000, 256, 3, "landscape"),
11
+ ("samples/portrait.jpg", 2, 1000, 256, 3, "portrait"),
12
+ ("samples/flower.jpg", 2, 1000, 256, 3, "flower"),
13
+ ]
14
+
15
+ print("QUICK PRE-TRAINING (1000 steps each)")
16
+ print("=" * 60)
17
+
18
+ for i, (img_path, scale, steps, h_feat, h_layers, name) in enumerate(configs, 1):
19
+ print(f"\n[{i}/{len(configs)}] {name}: {scale}x @ {steps} steps")
20
+
21
+ try:
22
+ image = Image.open(img_path)
23
+ results = super_resolve_image(
24
+ input_image=image,
25
+ scale_factor=scale,
26
+ training_steps=steps,
27
+ hidden_features=h_feat,
28
+ hidden_layers=h_layers,
29
+ use_cache=True,
30
+ image_name=name
31
+ )
32
+ print(f" ✓ Cached!")
33
+ except Exception as e:
34
+ print(f" ✗ Error: {e}")
35
+
36
+ print("\n" + "=" * 60)
37
+ print("DONE!")
38
+
39
+ # List cached models
40
+ cache_dir = "model_cache"
41
+ if os.path.exists(cache_dir):
42
+ models = [f for f in os.listdir(cache_dir) if f.endswith('.pkl')]
43
+ print(f"\nCached models: {len(models)}")
44
+ for model in sorted(models):
45
+ size = os.path.getsize(os.path.join(cache_dir, model)) / 1024
46
+ print(f" {model} ({size:.1f} KB)")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=4.0.0
4
+ numpy>=1.24.0
5
+ Pillow>=10.0.0
6
+ matplotlib>=3.7.0
samples/cat.jpg ADDED
samples/flower.jpg ADDED
samples/landscape.jpg ADDED
samples/portrait.jpg ADDED
siren.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
+ class SineLayer(nn.Module):
7
+ """Sine activation layer for SIREN network.
8
+
9
+ Args:
10
+ in_features: Number of input features
11
+ out_features: Number of output features
12
+ bias: Whether to use bias
13
+ is_first: Whether this is the first layer (uses different initialization)
14
+ omega_0: Frequency parameter for sine activation
15
+ """
16
+
17
+ def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
18
+ super().__init__()
19
+ self.omega_0 = omega_0
20
+ self.is_first = is_first
21
+ self.in_features = in_features
22
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
23
+ self.init_weights()
24
+
25
+ def init_weights(self):
26
+ with torch.no_grad():
27
+ if self.is_first:
28
+ self.linear.weight.uniform_(-1 / self.in_features,
29
+ 1 / self.in_features)
30
+ else:
31
+ self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
32
+ np.sqrt(6 / self.in_features) / self.omega_0)
33
+
34
+ def forward(self, x):
35
+ return torch.sin(self.omega_0 * self.linear(x))
36
+
37
+
38
+ class SIREN(nn.Module):
39
+ """SIREN network for implicit neural representations.
40
+
41
+ Args:
42
+ in_features: Number of input features (2 for image coordinates)
43
+ hidden_features: Number of hidden units in each layer
44
+ hidden_layers: Number of hidden layers
45
+ out_features: Number of output features (3 for RGB)
46
+ outermost_linear: Whether to use linear activation in the last layer
47
+ first_omega_0: Frequency parameter for first layer
48
+ hidden_omega_0: Frequency parameter for hidden layers
49
+ """
50
+
51
+ def __init__(self, in_features=2, hidden_features=256, hidden_layers=3,
52
+ out_features=3, outermost_linear=True,
53
+ first_omega_0=30, hidden_omega_0=30):
54
+ super().__init__()
55
+
56
+ self.net = []
57
+ self.net.append(SineLayer(in_features, hidden_features,
58
+ is_first=True, omega_0=first_omega_0))
59
+
60
+ for i in range(hidden_layers):
61
+ self.net.append(SineLayer(hidden_features, hidden_features,
62
+ is_first=False, omega_0=hidden_omega_0))
63
+
64
+ if outermost_linear:
65
+ final_linear = nn.Linear(hidden_features, out_features)
66
+
67
+ with torch.no_grad():
68
+ final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
69
+ np.sqrt(6 / hidden_features) / hidden_omega_0)
70
+
71
+ self.net.append(final_linear)
72
+ else:
73
+ self.net.append(SineLayer(hidden_features, out_features,
74
+ is_first=False, omega_0=hidden_omega_0))
75
+
76
+ self.net = nn.Sequential(*self.net)
77
+
78
+ def forward(self, coords):
79
+ output = self.net(coords)
80
+ return output
utils.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import hashlib
6
+ import os
7
+ import pickle
8
+
9
+
10
+ def get_image_coordinates(H, W):
11
+ """Generate normalized coordinate grid for image.
12
+
13
+ Args:
14
+ H: Image height
15
+ W: Image width
16
+
17
+ Returns:
18
+ coords: Tensor of shape (H*W, 2) with normalized coordinates in [-1, 1]
19
+ """
20
+ x = torch.linspace(-1, 1, W)
21
+ y = torch.linspace(-1, 1, H)
22
+
23
+ # Create meshgrid
24
+ Y, X = torch.meshgrid(y, x, indexing='ij')
25
+
26
+ # Stack and reshape to (H*W, 2)
27
+ coords = torch.stack([X, Y], dim=-1).reshape(-1, 2)
28
+
29
+ return coords
30
+
31
+
32
+ def image_to_tensor(image):
33
+ """Convert PIL Image to normalized tensor.
34
+
35
+ Args:
36
+ image: PIL Image
37
+
38
+ Returns:
39
+ Tensor of shape (H*W, 3) with values in [0, 1]
40
+ """
41
+ # Convert to RGB if not already
42
+ if image.mode != 'RGB':
43
+ image = image.convert('RGB')
44
+
45
+ # Convert to tensor and normalize to [0, 1]
46
+ img_tensor = transforms.ToTensor()(image) # (C, H, W)
47
+ img_tensor = img_tensor.permute(1, 2, 0) # (H, W, C)
48
+ img_tensor = img_tensor.reshape(-1, 3) # (H*W, 3)
49
+
50
+ return img_tensor
51
+
52
+
53
+ def tensor_to_image(tensor, H, W):
54
+ """Convert tensor back to PIL Image.
55
+
56
+ Args:
57
+ tensor: Tensor of shape (H*W, 3) with values in [0, 1]
58
+ H: Image height
59
+ W: Image width
60
+
61
+ Returns:
62
+ PIL Image
63
+ """
64
+ # Reshape to (H, W, C)
65
+ img = tensor.reshape(H, W, 3)
66
+
67
+ # Clamp to [0, 1]
68
+ img = torch.clamp(img, 0, 1)
69
+
70
+ # Convert to numpy and scale to [0, 255]
71
+ img = (img.cpu().numpy() * 255).astype(np.uint8)
72
+
73
+ # Convert to PIL Image
74
+ return Image.fromarray(img)
75
+
76
+
77
+ def downsample_image(image, scale_factor):
78
+ """Downsample image by scale_factor.
79
+
80
+ Args:
81
+ image: PIL Image
82
+ scale_factor: Downsampling factor (e.g., 2 for half size)
83
+
84
+ Returns:
85
+ Downsampled PIL Image
86
+ """
87
+ W, H = image.size
88
+ new_W = W // scale_factor
89
+ new_H = H // scale_factor
90
+
91
+ return image.resize((new_W, new_H), Image.BICUBIC)
92
+
93
+
94
+ def train_siren(model, coords, pixels, num_steps=2000, learning_rate=1e-4, device='cpu'):
95
+ """Train SIREN model on image.
96
+
97
+ Args:
98
+ model: SIREN model
99
+ coords: Coordinate tensor (H*W, 2)
100
+ pixels: Pixel values tensor (H*W, 3)
101
+ num_steps: Number of training steps
102
+ learning_rate: Learning rate
103
+ device: Device to train on
104
+
105
+ Returns:
106
+ Trained model and training losses
107
+ """
108
+ model = model.to(device)
109
+ coords = coords.to(device)
110
+ pixels = pixels.to(device)
111
+
112
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
113
+
114
+ losses = []
115
+
116
+ for step in range(num_steps):
117
+ # Forward pass
118
+ pred_pixels = model(coords)
119
+
120
+ # Compute loss
121
+ loss = torch.nn.functional.mse_loss(pred_pixels, pixels)
122
+
123
+ # Backward pass
124
+ optimizer.zero_grad()
125
+ loss.backward()
126
+ optimizer.step()
127
+
128
+ losses.append(loss.item())
129
+
130
+ # Print progress
131
+ if (step + 1) % 200 == 0:
132
+ print(f"Step {step + 1}/{num_steps}, Loss: {loss.item():.6f}")
133
+
134
+ return model, losses
135
+
136
+
137
+ def compute_psnr(img1, img2):
138
+ """Compute Peak Signal-to-Noise Ratio between two images.
139
+
140
+ Args:
141
+ img1: First image tensor (H*W, 3) in [0, 1]
142
+ img2: Second image tensor (H*W, 3) in [0, 1]
143
+
144
+ Returns:
145
+ PSNR value in dB
146
+ """
147
+ mse = torch.nn.functional.mse_loss(img1, img2)
148
+ if mse == 0:
149
+ return float('inf')
150
+ psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
151
+ return psnr.item()
152
+
153
+
154
+ def compute_mae(img1, img2):
155
+ """Compute Mean Absolute Error between two images.
156
+
157
+ Args:
158
+ img1: First image tensor (H*W, 3) in [0, 1]
159
+ img2: Second image tensor (H*W, 3) in [0, 1]
160
+
161
+ Returns:
162
+ MAE value
163
+ """
164
+ mae = torch.nn.functional.l1_loss(img1, img2)
165
+ return mae.item()
166
+
167
+
168
+ def compute_ssim_simple(img1, img2, window_size=11):
169
+ """Compute simplified SSIM between two images.
170
+
171
+ Args:
172
+ img1: First image tensor (H*W, 3) in [0, 1]
173
+ img2: Second image tensor (H*W, 3) in [0, 1]
174
+ window_size: Window size for local statistics
175
+
176
+ Returns:
177
+ SSIM value in [0, 1]
178
+ """
179
+ # Simplified SSIM - compute channel-wise
180
+ c1 = 0.01 ** 2
181
+ c2 = 0.03 ** 2
182
+
183
+ mu1 = img1.mean()
184
+ mu2 = img2.mean()
185
+
186
+ sigma1_sq = ((img1 - mu1) ** 2).mean()
187
+ sigma2_sq = ((img2 - mu2) ** 2).mean()
188
+ sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()
189
+
190
+ ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / \
191
+ ((mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2))
192
+
193
+ return ssim.item()
194
+
195
+
196
+ def get_model_cache_path(image_path, scale_factor, training_steps, hidden_features, hidden_layers):
197
+ """Generate cache path for trained model.
198
+
199
+ Args:
200
+ image_path: Path to image
201
+ scale_factor: Upscaling factor
202
+ training_steps: Number of training steps
203
+ hidden_features: Network width
204
+ hidden_layers: Network depth
205
+
206
+ Returns:
207
+ Cache file path
208
+ """
209
+ cache_dir = "model_cache"
210
+ os.makedirs(cache_dir, exist_ok=True)
211
+
212
+ # Extract image name from path (without extension)
213
+ if "/" in image_path:
214
+ image_name = image_path.split("/")[-1].split(".")[0]
215
+ else:
216
+ image_name = image_path.split(".")[0]
217
+
218
+ # Create descriptive filename
219
+ filename = f"{training_steps}steps_{scale_factor}x_{image_name}_h{hidden_features}_l{hidden_layers}.pkl"
220
+
221
+ return os.path.join(cache_dir, filename)
222
+
223
+
224
+ def save_model(model, cache_path):
225
+ """Save model to cache.
226
+
227
+ Args:
228
+ model: SIREN model
229
+ cache_path: Path to save model
230
+ """
231
+ with open(cache_path, 'wb') as f:
232
+ pickle.dump(model.state_dict(), f)
233
+ print(f"Model saved to cache: {cache_path}")
234
+
235
+
236
+ def load_model(model, cache_path):
237
+ """Load model from cache.
238
+
239
+ Args:
240
+ model: SIREN model (architecture must match)
241
+ cache_path: Path to cached model
242
+
243
+ Returns:
244
+ Loaded model or None if cache doesn't exist
245
+ """
246
+ if os.path.exists(cache_path):
247
+ with open(cache_path, 'rb') as f:
248
+ model.load_state_dict(pickle.load(f))
249
+ print(f"Model loaded from cache: {cache_path}")
250
+ return model
251
+ return None