# siren-super-resolution / pretrain_quick.py
# Author: Nipun
# Commit message: Complete SIREN super-resolution demo with improvements
# Commit: 691ba3c
"""Quick pre-training with reduced steps for faster caching."""
from PIL import Image
import os
from app import super_resolve_image
# Quick configurations - reduced steps for faster pre-training.
# Every sample uses the same settings (2x scale, 1000 steps, 256 hidden
# features, 3 hidden layers), so the rows are generated from the sample
# names rather than written out four times.
# Row layout: (image_path, scale, steps, hidden_features, hidden_layers, name)
configs = [
    (f"samples/{sample}.jpg", 2, 1000, 256, 3, sample)
    for sample in ("cat", "landscape", "portrait", "flower")
]
print("QUICK PRE-TRAINING (1000 steps each)")
print("=" * 60)
# Run every configuration in turn. This is a best-effort batch: one bad
# image must not abort the rest, so errors are reported and collected
# instead of propagating.
failed = []
for i, (img_path, scale, steps, h_feat, h_layers, name) in enumerate(configs, 1):
    print(f"\n[{i}/{len(configs)}] {name}: {scale}x @ {steps} steps")
    try:
        # PIL's Image.open is lazy and keeps the file open; the context
        # manager guarantees the handle is released after training.
        with Image.open(img_path) as image:
            # The return value is not needed here: the call is made for its
            # caching side effect (use_cache=True). Presumably the model is
            # persisted under model_cache/ by app.super_resolve_image —
            # NOTE(review): confirm in app.py.
            super_resolve_image(
                input_image=image,
                scale_factor=scale,
                training_steps=steps,
                hidden_features=h_feat,
                hidden_layers=h_layers,
                use_cache=True,
                image_name=name,
            )
        print(" ✓ Cached!")
    except Exception as e:  # deliberate catch-all: report and continue
        failed.append(name)
        print(f" ✗ Error: {type(e).__name__}: {e}")
print("\n" + "=" * 60)
print("DONE!")
# Summarise outcomes so a run where images failed is not mistaken for success.
print(f"Succeeded: {len(configs) - len(failed)}/{len(configs)}")
if failed:
    print(f"Failed: {', '.join(failed)}")
# List cached models: report every pickled (.pkl) file left in the cache
# directory, in filename order, with its size in kilobytes (bytes / 1024).
# Nothing is printed when the directory does not exist.
cache_dir = "model_cache"
if os.path.exists(cache_dir):
    pkl_files = sorted(
        entry for entry in os.listdir(cache_dir) if entry.endswith('.pkl')
    )
    print(f"\nCached models: {len(pkl_files)}")
    for pkl_name in pkl_files:
        kib = os.path.getsize(os.path.join(cache_dir, pkl_name)) / 1024
        print(f" {pkl_name} ({kib:.1f} KB)")