Spaces:
Build error
Build error
| import torch | |
| import os | |
| def get_device(): | |
| if torch.cuda.is_available(): | |
| return 'cuda' | |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | |
| return "mps" | |
| return 'cpu' | |
| def set_seed(seed=1337): | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed(seed) | |
| def save_model(model, optimizer, loss, epoch, path='model.pt'): | |
| # Convert model to half precision | |
| model_to_save = model.half() | |
| # Save in half precision | |
| torch.save({ | |
| 'model_state_dict': model_to_save.state_dict(), | |
| 'loss': loss, | |
| 'epoch': epoch | |
| }, path, _use_new_zipfile_serialization=False) | |
| # Convert back to original precision | |
| model.float() | |
| print(f"Model saved to {path}") | |
| def load_model(model, optimizer=None, path='model.pt'): | |
| if os.path.exists(path): | |
| checkpoint = torch.load(path, weights_only=True) | |
| model.load_state_dict(checkpoint['model_state_dict']) | |
| if optimizer is not None and 'optimizer_state_dict' in checkpoint: | |
| optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |
| return checkpoint.get('epoch', 0), checkpoint.get('loss', None) | |
| return 0, None |