File size: 1,997 Bytes
6810eb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import random
import torch
import yaml
import json

def load_config(config_path: str):
    """
    Load configuration from a YAML or JSON file.

    The parser is chosen from the file extension, compared case-insensitively:
    ``.yaml``/``.yml`` use ``yaml.safe_load`` and ``.json`` uses ``json.load``.

    Args:
        config_path: Path to the configuration file.

    Returns:
        The parsed document as returned by the selected loader.

    Raises:
        ValueError: If the extension is not ``.yaml``, ``.yml`` or ``.json``.
            This is raised before the file is opened.
    """
    ext = os.path.splitext(config_path)[1].lower()
    if ext in (".yaml", ".yml"):
        loader = yaml.safe_load  # safe_load: never constructs arbitrary Python objects
    elif ext == ".json":
        loader = json.load
    else:
        raise ValueError("Unsupported config file format: " + ext)
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return loader(f)

def save_checkpoint(model, optimizer, filepath: str):
    """
    Save model and optimizer state to a file.

    The saved object is a dict with key ``"model_state"`` and, only when
    ``optimizer`` is not ``None``, ``"optimizer_state"``. Missing parent
    directories of ``filepath`` are created.

    Args:
        model: Module whose ``state_dict()`` is saved. If it has a ``.module``
            attribute, that inner module's state is saved instead. Torch's
            DataParallel/DistributedDataParallel wrappers expose the wrapped
            network this way.
        optimizer: Optimizer whose ``state_dict()`` is saved, or ``None`` to skip.
        filepath: Destination path. It may be a bare filename with no directory.
    """
    parent = os.path.dirname(filepath)
    # os.makedirs("") raises FileNotFoundError, so skip directory creation for
    # bare filenames such as "ckpt.pt" that are written to the current directory.
    if parent:
        os.makedirs(parent, exist_ok=True)
    net = model.module if hasattr(model, "module") else model
    checkpoint = {"model_state": net.state_dict()}
    if optimizer is not None:
        checkpoint["optimizer_state"] = optimizer.state_dict()
    torch.save(checkpoint, filepath)

def load_checkpoint(model, optimizer, filepath: str, device: str):
    """
    Restore model state, and optimizer state when available, from ``filepath``.

    Two file layouts are accepted:

    - a dict holding ``"model_state"`` and optionally ``"optimizer_state"``;
    - a bare model state dict, used as-is when ``"model_state"`` is absent.

    Tensors are mapped onto ``device`` while loading. If ``model`` has a
    ``.module`` attribute, the state is loaded into that inner module.
    Optimizer state is restored only when both ``optimizer`` and a saved
    ``"optimizer_state"`` are present.
    """
    # NOTE(review): torch.load unpickles its input; only load trusted files.
    ckpt = torch.load(filepath, map_location=device)
    weights = ckpt.get("model_state", ckpt)
    saved_optim = ckpt.get("optimizer_state")

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(weights)

    if optimizer is None or saved_optim is None:
        return
    optimizer.load_state_dict(saved_optim)

def set_seed(seed: int):
    """
    Seed Python's ``random`` module and PyTorch with ``seed`` for reproducibility.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_device():
    """
    Name the default compute device.

    Returns:
        ``"cuda"`` if ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"