siren-super-resolution / pretrain_models.py
Nipun's picture
Complete SIREN super-resolution demo with improvements
691ba3c
"""Pre-train SIREN models for common settings to populate cache."""
from PIL import Image
import os
from app import super_resolve_image
# Common configurations to pre-train
configs = [
# (image_path, scale, steps, hidden_features, hidden_layers, name)
("samples/cat.jpg", 2, 2000, 256, 3, "cat"),
("samples/landscape.jpg", 4, 3000, 256, 3, "landscape"),
("samples/portrait.jpg", 2, 2000, 256, 3, "portrait"),
("samples/flower.jpg", 4, 3000, 256, 4, "flower"),
]
print("=" * 60)
print("PRE-TRAINING SIREN MODELS FOR COMMON SETTINGS")
print("=" * 60)
print()

# Names of configurations whose training raised. They are reported in the
# summary so a partial failure is not hidden behind "PRE-TRAINING COMPLETE!".
failed = []
for i, (img_path, scale, steps, h_feat, h_layers, name) in enumerate(configs, 1):
    print(f"\n[{i}/{len(configs)}] Training: {name}")
    print(f" Image: {img_path}")
    print(f" Settings: {scale}x scale, {steps} steps, {h_feat} features, {h_layers} layers")
    print("-" * 60)
    try:
        # `with` releases the underlying file handle after training instead of
        # leaking one per configuration; the call happens inside the block so
        # the image is still open while super_resolve_image uses it.
        with Image.open(img_path) as image:
            # Train and cache (use_cache=True will save the model). The return
            # value is not needed by this script.
            super_resolve_image(
                input_image=image,
                scale_factor=scale,
                training_steps=steps,
                hidden_features=h_feat,
                hidden_layers=h_layers,
                use_cache=True,
                image_name=name,
            )
    except Exception as e:
        # Script-level boundary: record the failure and continue with the
        # remaining configurations rather than aborting the whole run.
        failed.append(name)
        print(f" ✗ Error: {e}")
    else:
        print(" ✓ Model trained and cached successfully!")

print("\n" + "=" * 60)
print("PRE-TRAINING COMPLETE!")
print(f"{len(configs) - len(failed)}/{len(configs)} models trained")
if failed:
    print(f"Failed: {', '.join(failed)}")
print("=" * 60)
# List cached models.
# NOTE(review): this path is hard-coded here and presumably has to match the
# cache directory used by app.super_resolve_image — confirm against app.py.
cache_dir = "model_cache"
# isdir (not exists): a regular file named "model_cache" would otherwise make
# os.listdir raise NotADirectoryError.
if os.path.isdir(cache_dir):
    models = [f for f in os.listdir(cache_dir) if f.endswith('.pkl')]
    print(f"\nCached models ({len(models)}):")
    for model in sorted(models):
        # Size in KiB, for display only.
        size = os.path.getsize(os.path.join(cache_dir, model)) / 1024
        print(f" • {model} ({size:.1f} KB)")
else:
    print("\nNo cache directory found.")