Spaces:
Build error
Build error
File size: 1,224 Bytes
3aa6cf7 4462254 3aa6cf7 4462254 3aa6cf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
import os
def get_device():
    """Pick the preferred torch device name available on this machine.

    Preference order is CUDA, then Apple MPS (only checked when this torch
    build exposes ``torch.backends.mps``), then CPU.

    Returns:
        One of the strings "cuda", "mps" or "cpu".
    """
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        return 'cuda'
    # Older torch builds have no MPS backend module at all, hence the hasattr guard.
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ok else 'cpu'
def set_seed(seed=1337):
    """Seed torch's random number generators for reproducible runs.

    Always calls ``torch.manual_seed``; additionally calls
    ``torch.cuda.manual_seed`` when CUDA is available.

    Args:
        seed: integer seed applied to every selected generator.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
def save_model(model, optimizer, loss, epoch, path='model.pt', *, save_optimizer=False):
    """Save a half-precision checkpoint of ``model`` to ``path``.

    Floating-point entries of the model's state dict are stored as float16 to
    shrink the file; other entries are stored unchanged. The live ``model`` is
    NOT modified: casting happens on detached copies, so its parameters keep
    their original dtype and exact values.

    Args:
        model: ``torch.nn.Module`` whose state dict is saved under
            ``'model_state_dict'``.
        optimizer: optimizer whose ``state_dict()`` is saved under
            ``'optimizer_state_dict'`` only when ``save_optimizer`` is true;
            otherwise it is ignored (kept for call-site compatibility).
        loss: value stored under ``'loss'``.
        epoch: value stored under ``'epoch'``.
        path: destination file.
        save_optimizer: opt-in flag (default False keeps the original, smaller
            file). The optimizer state is stored at its own precision.
    """
    # Do NOT call model.half()/model.float() here: that in-place round trip
    # permanently rounds the live weights to float16 precision and turns
    # bfloat16/float64 models into float32.
    # keep_vars=True returns the same Parameter object for tied weights, so the
    # id-keyed cache casts each shared tensor once and torch.save still stores
    # a single storage for it (no duplicated tied embeddings in the file).
    cast_cache = {}
    half_state = {}
    for name, value in model.state_dict(keep_vars=True).items():
        if not isinstance(value, torch.Tensor):
            # e.g. module extra_state; left untouched just like Module.half().
            half_state[name] = value
            continue
        key = id(value)
        if key not in cast_cache:
            tensor = value.detach()
            cast_cache[key] = tensor.to(torch.float16) if tensor.is_floating_point() else tensor
        half_state[name] = cast_cache[key]

    checkpoint = {
        'model_state_dict': half_state,
        'loss': loss,
        'epoch': epoch,
    }
    if save_optimizer and optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, path, _use_new_zipfile_serialization=False)
    print(f"Model saved to {path}")
def load_model(model, optimizer=None, path='model.pt'):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    Args:
        model: ``torch.nn.Module`` that receives ``checkpoint['model_state_dict']``
            via ``load_state_dict`` (which copies values into the existing
            parameters, so a float16 checkpoint loads into a float32 model).
        optimizer: optional optimizer; restored only when the checkpoint has an
            ``'optimizer_state_dict'`` entry.
        path: checkpoint file to read.

    Returns:
        ``(epoch, loss)`` from the checkpoint (defaulting to ``0`` and ``None``
        when absent), or ``(0, None)`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return 0, None
    # map_location="cpu": a checkpoint written from a CUDA device would otherwise
    # fail to deserialize on a CPU/MPS-only host. load_state_dict and
    # Optimizer.load_state_dict move the values onto the parameters' devices.
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint.get('epoch', 0), checkpoint.get('loss', None)