File size: 3,157 Bytes
1601799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Utility functions for LexiMind.

Consolidated utilities including:
- Model checkpoint I/O
- Label metadata handling
- Seed management for reproducibility
- YAML configuration loading

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List

import numpy as np
import torch

# --------------- Checkpoint I/O ---------------


def save_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
    """Write ``model``'s state dict to ``path`` with torch.compile prefixes removed.

    Every occurrence of ``"_orig_mod."`` in a parameter/buffer name is dropped,
    so a checkpoint taken from a ``torch.compile``-wrapped model uses the same
    keys as one taken from the plain module.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        path: Destination file; missing parent directories are created.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    cleaned = {}
    for name, tensor in model.state_dict().items():
        # torch.compile keeps the original module under ``_orig_mod``; drop that segment.
        cleaned[name.replace("_orig_mod.", "")] = tensor
    torch.save(cleaned, target)


def load_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
    """Load a state dict from ``path`` into ``model``, handling torch.compile artifacts.

    Every ``"_orig_mod."`` segment is removed from the checkpoint keys, matching
    what ``save_checkpoint`` writes. The stripped keys are then re-mapped onto
    the key names ``model`` itself expects. The same checkpoint therefore loads
    into any of these targets:

    - a plain module;
    - a ``torch.compile``-wrapped module, whose keys carry an ``_orig_mod.``
      prefix;
    - a module containing compiled submodules.

    Args:
        model: Destination module, updated in place via ``load_state_dict``
            (default strict key matching).
        path: Checkpoint file containing a saved state dict.

    Raises:
        RuntimeError: From ``load_state_dict`` if keys or shapes still mismatch.
    """
    state = torch.load(path, map_location="cpu", weights_only=True)
    state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}

    # A compiled model reports keys such as "_orig_mod.layer.weight". Map each
    # stripped name back to the exact name this model expects so strict loading
    # succeeds. Keys the model doesn't know pass through unchanged and are
    # reported by load_state_dict as before.
    expected = {k.replace("_orig_mod.", ""): k for k in model.state_dict()}
    state = {expected.get(k, k): v for k, v in state.items()}
    model.load_state_dict(state)


# --------------- Label Metadata ---------------


@dataclass
class LabelMetadata:
    """Emotion and topic label vocabularies, with convenience size accessors."""

    # Emotion label names.
    emotion: List[str]
    # Topic label names.
    topic: List[str]

    @property
    def num_emotions(self) -> int:
        """Number of emotion labels."""
        count = len(self.emotion)
        return count

    @property
    def num_topics(self) -> int:
        """Number of topic labels."""
        count = len(self.topic)
        return count


def load_labels(path: str | Path) -> LabelMetadata:
    """Load label metadata from a JSON file.

    The file must hold a JSON object with non-empty ``emotion`` and ``topic``
    lists; the plural keys ``emotions`` / ``topics`` are accepted as fallbacks.

    Args:
        path: Location of the labels JSON file.

    Returns:
        A ``LabelMetadata`` built from the two lists.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is not a JSON object, or either label set is
            missing, empty, or not a list. Malformed JSON also surfaces as
            ``json.JSONDecodeError``, a ``ValueError`` subclass.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Labels not found: {path}")

    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    # Reject non-object JSON (e.g. a top-level list) up front; otherwise
    # ``data.get`` below fails with an unhelpful AttributeError.
    if not isinstance(data, dict):
        raise ValueError(f"Labels file '{path}' must contain a JSON object")

    emotion = data.get("emotion") or data.get("emotions", [])
    topic = data.get("topic") or data.get("topics", [])

    # A bare string would pass a truthiness check but produce per-character
    # "labels" (and wrong num_emotions / num_topics), so require real lists.
    if not (isinstance(emotion, list) and emotion and isinstance(topic, list) and topic):
        raise ValueError("Labels file must contain 'emotion' and 'topic' lists")

    return LabelMetadata(emotion=emotion, topic=topic)


def save_labels(labels: LabelMetadata, path: str | Path) -> None:
    """Write ``labels`` to ``path`` as a 2-space-indented JSON object.

    The object uses the keys ``"emotion"`` and ``"topic"``. Missing parent
    directories are created first.

    Args:
        labels: Label vocabularies to serialize (``.emotion`` and ``.topic`` are read).
        path: Destination JSON file (written as UTF-8).
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("w", encoding="utf-8") as handle:
        payload = {"emotion": labels.emotion, "topic": labels.topic}
        json.dump(payload, handle, indent=2)


# --------------- Reproducibility ---------------


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    When CUDA is available, the generators on all CUDA devices are seeded too.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# --------------- Config Loading ---------------


@dataclass
class Config:
    """Thin dataclass wrapper around a parsed configuration mapping."""

    # The parsed top-level configuration mapping.
    data: dict


def load_yaml(path: str | Path) -> Config:
    """Parse a YAML file and wrap its top-level mapping in a ``Config``.

    Parsing goes through ``yaml.safe_load``, so only standard YAML types are built.

    Args:
        path: YAML file to read (UTF-8).

    Returns:
        A ``Config`` whose ``data`` is the parsed mapping.

    Raises:
        ValueError: If the document's top level is not a mapping. This includes
            an empty document, which ``safe_load`` parses to ``None``.
    """
    import yaml

    source = Path(path)
    with source.open("r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if isinstance(parsed, dict):
        return Config(data=parsed)
    raise ValueError(f"YAML '{path}' must contain a mapping")