Spaces:
Runtime error
Runtime error
| import yaml | |
| import json | |
| import torch | |
| import random | |
| import numpy as np | |
| from pathlib import Path | |
def load_config(config_path):
    """Load a configuration file in YAML or JSON format.

    The format is chosen from the file extension: ``.yaml``/``.yml`` are
    parsed with ``yaml.safe_load`` and ``.json`` with ``json.load``. Matching
    is case-insensitive.

    Args:
        config_path: Path to the config file, as a ``str`` or ``pathlib.Path``.

    Returns:
        The parsed document, as returned by the selected loader.

    Raises:
        ValueError: If the extension is not a supported config format. This
            is checked before the file is opened, so an unsupported path is
            always reported as a format error rather than an I/O error.
    """
    # str() lets pathlib.Path (which has no .endswith) be passed as well.
    name = str(config_path).lower()
    if name.endswith(('.yaml', '.yml')):
        loader = yaml.safe_load  # safe_load: never constructs arbitrary objects
    elif name.endswith('.json'):
        loader = json.load
    else:
        raise ValueError(f"Unsupported config format: {config_path}")
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        return loader(f)
def save_config(config, save_path):
    """Write ``config`` to a YAML or JSON file, creating parent directories.

    The format is chosen from the file extension: ``.yaml``/``.yml`` are
    written with ``yaml.dump(..., default_flow_style=False)`` and ``.json``
    with ``json.dumps(..., indent=2)``. Matching is case-insensitive.

    Args:
        config: The object to serialize.
        save_path: Destination path, as a ``str`` or ``pathlib.Path``.

    Raises:
        ValueError: If the extension is not a supported config format. This
            is checked before any directory is created or any file is opened,
            so a rejected call never creates or truncates a file.
    """
    # str() lets pathlib.Path (which has no .endswith) be passed as well.
    name = str(save_path).lower()
    # Serialize to a string first: a format or serialization error must not
    # leave behind a truncated file (opening with 'w' truncates immediately).
    if name.endswith(('.yaml', '.yml')):
        text = yaml.dump(config, default_flow_style=False)
    elif name.endswith('.json'):
        text = json.dumps(config, indent=2)
    else:
        raise ValueError(f"Unsupported config format: {save_path}")
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(text)
def get_device(device_name='auto'):
    """Return a ``torch.device`` for the requested name.

    Args:
        device_name: Any value accepted by ``torch.device``, or the string
            ``'auto'`` to pick ``'cuda'`` when ``torch.cuda.is_available()``
            reports true and ``'cpu'`` otherwise.

    Returns:
        The constructed ``torch.device``.
    """
    if device_name != 'auto':
        # Explicit request: hand it straight to torch.
        return torch.device(device_name)
    has_gpu = torch.cuda.is_available()
    return torch.device('cuda' if has_gpu else 'cpu')
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When ``torch.cuda.is_available()`` is true, the CUDA generators are
    seeded as well. cuDNN is then configured for deterministic kernels with
    autotuning (``benchmark``) disabled, trading speed for reproducibility.

    Args:
        seed: Integer seed shared by every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # NOTE(review): the original indentation was lost; these flags are
    # reconstructed as unconditional (they are inert when CUDA is absent).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False