File size: 843 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_losses(log_dir):
    """Placeholder for plotting training losses from TensorBoard logs.

    This is currently a no-op: ``log_dir`` is accepted but never read, and
    the function always returns ``None``. In practice, point TensorBoard
    directly at ``log_dir`` to inspect the logged losses.
    """
    return None

def save_checkpoint(model, optimizer, epoch, path):
    """Serialize a training checkpoint to ``path`` with ``torch.save``.

    The saved object is a dict with three keys, in this order:
    ``'epoch'`` (the ``epoch`` argument as given), ``'model_state_dict'``
    and ``'optimizer_state_dict'`` (the results of the respective
    ``state_dict()`` calls).

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        optimizer: object exposing ``state_dict()`` (e.g. a
            ``torch.optim`` optimizer).
        epoch: value stored under ``'epoch'``.
        path: destination passed straight to ``torch.save``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, path)

def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Expects the dict layout written by ``save_checkpoint``: keys ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: object whose ``load_state_dict`` receives the saved model state.
        optimizer: object whose ``load_state_dict`` receives the saved
            optimizer state, or ``None`` to skip restoring it (e.g. when
            loading weights for inference only).
        path: checkpoint location passed to ``torch.load``.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"``
            lets a checkpoint containing GPU tensors load on a CPU-only
            machine. ``None`` (the default) keeps ``torch.load``'s own
            default placement, matching the previous behavior.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.

    Raises:
        KeyError: if a key this call needs is missing from the checkpoint.
    """
    # torch.load unpickles data; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

def show_samples(samples):
    """Display a channels-first image (e.g. a grid of generated samples).

    Args:
        samples: image as a tensor or array-like, shaped ``(C, H, W)`` with
            C of 1, 3 or 4, or a single-channel ``(H, W)``. Tensors may
            require grad or live on any device; they are detached and copied
            to the CPU before plotting.

    Raises:
        ValueError: if ``samples`` is neither 2-D nor 3-D.
    """
    if isinstance(samples, torch.Tensor):
        # .numpy() raises on tensors that require grad or are not on the CPU.
        image = samples.detach().cpu().numpy()
    else:
        image = np.asarray(samples)

    if image.ndim == 3:
        # Matplotlib expects channels-last (H, W, C).
        image = np.transpose(image, (1, 2, 0))
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis.
            image = image[..., 0]
    elif image.ndim != 2:
        raise ValueError(
            f"show_samples expects a (C, H, W) or (H, W) image, "
            f"got shape {image.shape}"
        )

    plt.figure(figsize=(10, 10))
    # cmap only applies to 2-D (single-channel) data; RGB(A) ignores it.
    plt.imshow(image, cmap='gray' if image.ndim == 2 else None)
    plt.axis('off')
    plt.show()