# sentiment_anals/src/utils/config.py
# Uploaded by abdou21367 in commit 839c56d ("Upload 64 files").
import yaml
import json
import torch
import random
import numpy as np
from pathlib import Path
def load_config(config_path):
    """Load a configuration document from a YAML or JSON file.

    The format is chosen from the file extension, compared
    case-insensitively: ``.yaml``/``.yml`` files are parsed with
    ``yaml.safe_load`` and ``.json`` files with ``json.load``.

    Args:
        config_path: Path to the config file, as a ``str`` or a
            ``pathlib.Path`` / other path-like object.

    Returns:
        The parsed document. Note that ``yaml.safe_load`` returns ``None``
        for an empty YAML file.

    Raises:
        ValueError: If the extension is not a supported format. This is
            checked before the file is opened, so it is raised even when the
            path does not exist.
        OSError: If the file cannot be opened.
    """
    # Path() accepts both str and path-like objects; the original
    # str.endswith() call failed on pathlib.Path arguments.
    name = Path(config_path).name.lower()
    if name.endswith(('.yaml', '.yml')):
        loader = yaml.safe_load
    elif name.endswith('.json'):
        loader = json.load
    else:
        # Validate before touching the filesystem so the caller sees the real
        # problem rather than a FileNotFoundError for an unsupported path.
        raise ValueError(f"Unsupported config format: {config_path}")
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        return loader(f)
def save_config(config, save_path):
    """Write a configuration object to a YAML or JSON file.

    The format is chosen from the file extension, compared
    case-insensitively: ``.yaml``/``.yml`` are written with ``yaml.dump``
    (block style), and ``.json`` with ``json.dump`` (2-space indent). Missing
    parent directories are created.

    Args:
        config: The object to serialize (typically a dict).
        save_path: Destination path, as a ``str`` or a ``pathlib.Path`` /
            other path-like object.

    Raises:
        ValueError: If the extension is not a supported format. This is
            checked before any directory is created or any file is opened,
            so an existing file at ``save_path`` is left untouched.
        OSError: If the directory or file cannot be created or written.
    """
    target = Path(save_path)
    name = target.name.lower()
    is_yaml = name.endswith(('.yaml', '.yml'))
    if not is_yaml and not name.endswith('.json'):
        # Must happen before open(..., 'w'); otherwise an unsupported path
        # is truncated to an empty file before the error is raised.
        raise ValueError(f"Unsupported config format: {save_path}")
    target.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so output does not depend on the platform locale.
    with open(target, 'w', encoding='utf-8') as f:
        if is_yaml:
            yaml.dump(config, f, default_flow_style=False)
        else:
            json.dump(config, f, indent=2)
def get_device(device_name='auto'):
    """Resolve a device specifier into a ``torch.device``.

    Args:
        device_name: ``'auto'`` picks CUDA when ``torch.cuda.is_available()``
            is true and the CPU otherwise. Any other value is passed
            unchanged to ``torch.device`` (e.g. ``'cpu'``, ``'cuda:1'``).

    Returns:
        The resolved ``torch.device``.
    """
    if device_name != 'auto':
        return torch.device(device_name)
    cuda_ready = torch.cuda.is_available()
    return torch.device('cuda') if cuda_ready else torch.device('cpu')
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs.

    Always seeds ``random``, NumPy's global RNG and ``torch.manual_seed``.
    When CUDA is available it also seeds the current and all CUDA devices.
    It then sets cuDNN to deterministic mode with benchmark autotuning off.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False