# Snapshot of revision 8abfb97 (843 bytes).
import matplotlib.pyplot as plt
import numpy as np
import torch
def plot_losses(log_dir):
    """Placeholder for plotting training losses from TensorBoard logs.

    This is currently a no-op: ``log_dir`` is accepted but ignored and
    nothing is drawn. View the logged losses with TensorBoard itself.
    """
    # Intentionally unimplemented; TensorBoard's UI is the intended viewer.
    return None
def save_checkpoint(model, optimizer, epoch, path):
    """Serialize training state to ``path`` with ``torch.save``.

    The saved object is a dict with three entries, in this order:
    ``'epoch'`` (the value passed in), ``'model_state_dict'`` and
    ``'optimizer_state_dict'`` (the respective ``state_dict()`` results).

    Args:
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        epoch: Value stored under ``'epoch'``.
        path: Destination handed to ``torch.save``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, path)
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Reads a dict containing the keys ``'epoch'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'`` -- the layout written by ``save_checkpoint``
    in this module.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer to restore, or ``None`` to skip restoring
            optimizer state (e.g. when loading only for inference).
        path: File path or file-like object handed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'``
            lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
            ``None`` (the default) keeps ``torch.load``'s default behavior.

    Returns:
        The value stored under ``'epoch'``.

    Raises:
        KeyError: If a required key is missing from the checkpoint dict.
    """
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
def show_samples(samples):
    """Display a single channels-first image (e.g. a grid of samples).

    Args:
        samples: Image of shape ``(C, H, W)``, as a ``torch.Tensor`` or any
            array-like accepted by ``np.asarray``. Tensors may require grad
            or live on an accelerator; they are detached and copied to the
            CPU first. A single-channel image (``C == 1``) is shown as a 2-D
            grayscale array.

    Pixel values are passed to ``imshow`` unchanged (no rescaling is done).
    """
    if isinstance(samples, torch.Tensor):
        # .numpy() raises on tensors that require grad or are not on the CPU.
        array = samples.detach().cpu().numpy()
    else:
        array = np.asarray(samples)
    # CHW -> HWC, the channel-last layout imshow expects.
    image = np.transpose(array, (1, 2, 0))
    if image.shape[-1] == 1:
        # imshow rejects (H, W, 1); drop the singleton channel axis.
        image = image[..., 0]
    plt.figure(figsize=(10, 10))
    plt.imshow(image, cmap='gray' if image.ndim == 2 else None)
    plt.axis('off')
    plt.show()