File size: 1,862 Bytes
691ba3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Pre-train SIREN models for common settings to populate cache."""
from PIL import Image
import os
from app import super_resolve_image

# Common configurations to pre-train.
# Each tuple: (image_path, scale, steps, hidden_features, hidden_layers, name)
configs = [
    ("samples/cat.jpg", 2, 2000, 256, 3, "cat"),
    ("samples/landscape.jpg", 4, 3000, 256, 3, "landscape"),
    ("samples/portrait.jpg", 2, 2000, 256, 3, "portrait"),
    ("samples/flower.jpg", 4, 3000, 256, 4, "flower"),
]

print("=" * 60)
print("PRE-TRAINING SIREN MODELS FOR COMMON SETTINGS")
print("=" * 60)
print()

for i, (img_path, scale, steps, h_feat, h_layers, name) in enumerate(configs, 1):
    print(f"\n[{i}/{len(configs)}] Training: {name}")
    print(f"  Image: {img_path}")
    print(f"  Settings: {scale}x scale, {steps} steps, {h_feat} features, {h_layers} layers")
    print("-" * 60)

    try:
        # Context manager ensures the image file handle is released even
        # if training raises (plain Image.open leaked the handle).
        with Image.open(img_path) as image:
            # Train and cache; use_cache=True makes super_resolve_image
            # persist the trained model for later reuse. The return value
            # is not needed here — we only want the caching side effect.
            super_resolve_image(
                input_image=image,
                scale_factor=scale,
                training_steps=steps,
                hidden_features=h_feat,
                hidden_layers=h_layers,
                use_cache=True,
                image_name=name
            )

        print("  ✓ Model trained and cached successfully!")

    except Exception as e:
        # Best effort: one missing/broken sample must not abort the batch.
        print(f"  ✗ Error: {e}")

print("\n" + "=" * 60)
print("PRE-TRAINING COMPLETE!")
print("=" * 60)

# Summarize what ended up in the model cache directory.
cache_dir = "model_cache"
if not os.path.exists(cache_dir):
    print("\nNo cache directory found.")
else:
    models = sorted(f for f in os.listdir(cache_dir) if f.endswith('.pkl'))
    print(f"\nCached models ({len(models)}):")
    for model in models:
        # Report each cached model's on-disk size in kilobytes.
        size = os.path.getsize(os.path.join(cache_dir, model)) / 1024
        print(f"  • {model} ({size:.1f} KB)")