File size: 2,788 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torchvision
from torchvision.utils import save_image, make_grid
import os
import argparse
from datetime import datetime
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler_simple import FrequencyAwareNoise
from sample_simple import simple_sample

def load_model(checkpoint_path, device):
    """Load a model and its noise scheduler from a checkpoint file.

    Two checkpoint layouts are accepted:

    * a dict containing ``'model_state_dict'`` (optionally alongside
      ``'config'``, ``'epoch'`` and ``'loss'`` entries), or
    * a bare state dict, which is handed to ``load_state_dict`` as-is.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        A ``(model, noise_scheduler, config)`` tuple. The model is returned
        in evaluation mode, ready for sampling.
    """
    print(f"Loading model from: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded. Newer PyTorch releases default to weights_only=True,
    # which may reject a pickled 'config' object — confirm against the
    # installed version.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Prefer the config stored with the checkpoint; otherwise use defaults.
    config = checkpoint['config'] if 'config' in checkpoint else Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint: weights plus bookkeeping metadata.
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Handle simple state dict (final model)
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # A freshly constructed nn.Module is in training mode; switch to
    # evaluation mode so dropout / normalisation layers behave
    # deterministically when the model is used for sampling.
    model.eval()

    return model, noise_scheduler, config

def test_checkpoint(checkpoint_path, device, n_samples=16):
    """Sample from one checkpoint and write the resulting grid to a PNG.

    The checkpoint is loaded with ``load_model``, ``simple_sample`` is asked
    for ``n_samples`` samples, and the returned grid is saved under a
    timestamped file name (relative path, no directory component).

    Args:
        checkpoint_path: Path of the checkpoint file to load.
        device: Device used for loading and sampling.
        n_samples: Number of samples requested from the sampler.

    Returns:
        The ``(samples, grid)`` pair produced by ``simple_sample``.
    """
    # The config returned by load_model is not needed here.
    net, scheduler, _unused_config = load_model(checkpoint_path, device)

    # The file name is fixed before sampling starts, so the stamp reflects
    # the start of the run.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_file = "test_samples_simple_" + stamp + ".png"

    print(f"Testing checkpoint with {n_samples} samples...")
    result = simple_sample(net, scheduler, device, n_samples=n_samples)
    samples, grid = result

    # Write the grid to disk without value normalisation.
    save_image(grid, out_file, normalize=False)
    print(f"Samples saved to: {out_file}")

    return samples, grid

def main():
    """Command-line entry point: parse options, pick a device, run the test.

    Options:
        --checkpoint  (required) path to the checkpoint file.
        --n_samples   number of samples to generate (default 16).
        --device      'cuda', 'cpu' or 'auto' (default); 'auto' selects CUDA
                      when ``torch.cuda.is_available()`` is true, else CPU.
    """
    cli = argparse.ArgumentParser(description='Test trained diffusion model (simple version)')
    cli.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint file')
    cli.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    cli.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')
    opts = cli.parse_args()

    # Resolve 'auto' to a concrete device string, then build the device once.
    device_name = opts.device
    if device_name == 'auto':
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    print(f"Using device: {device}")

    print("=== Testing Checkpoint with Simple DDPM ===")
    test_checkpoint(opts.checkpoint, device, opts.n_samples)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()