Spaces:
Sleeping
Sleeping
File size: 3,157 Bytes
1601799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""
Utility functions for LexiMind.
Consolidated utilities including:
- Model checkpoint I/O
- Label metadata handling
- Seed management for reproducibility
- YAML configuration loading
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List
import numpy as np
import torch
# --------------- Checkpoint I/O ---------------
def save_checkpoint(model: torch.nn.Module, path: str | Path) -> None:
    """Write ``model``'s state dict to ``path`` without torch.compile prefixes.

    Missing parent directories are created first. Every ``"_orig_mod."``
    segment is removed from the state-dict key names, so a checkpoint taken
    from a compiled model uses the same keys as one taken from the plain module.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file; its parent directories are created if needed.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    cleaned = {}
    for name, tensor in model.state_dict().items():
        cleaned[name.replace("_orig_mod.", "")] = tensor
    torch.save(cleaned, destination)
def load_checkpoint(
    model: torch.nn.Module,
    path: str | Path,
    *,
    strict: bool = True,
) -> None:
    """Load a state dict from ``path`` into ``model``, handling torch.compile artifacts.

    Every ``"_orig_mod."`` segment is first removed from the checkpoint's keys.
    This is the form written by ``save_checkpoint``. The keys are then mapped
    back onto the key names ``model`` itself reports. Loading therefore works
    whether ``model`` (or any of its submodules) is a plain module or a wrapper
    that keeps the real module under an ``_orig_mod`` attribute, as
    torch.compile does.

    Args:
        model: Module to load weights into (modified in place).
        path: File written by ``torch.save`` holding a mapping of tensors.
        strict: Forwarded to ``load_state_dict``. When True, missing or
            unexpected keys raise an error.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict`` is True and the
            checkpoint keys do not match the model.
    """
    state = torch.load(path, map_location="cpu", weights_only=True)
    stripped = {key.replace("_orig_mod.", ""): value for key, value in state.items()}
    # A compiled (sub)module reports its entries as "..._orig_mod.<name>".
    # Translate each stripped checkpoint key back to the exact name the model
    # expects. Keys the model does not know keep their stripped name, so that
    # strict mode reports them as unexpected.
    expected = {key.replace("_orig_mod.", ""): key for key in model.state_dict()}
    remapped = {expected.get(key, key): value for key, value in stripped.items()}
    model.load_state_dict(remapped, strict=strict)
# --------------- Label Metadata ---------------
@dataclass
class LabelMetadata:
    """Holds the emotion and topic label name lists.

    The ``num_*`` properties are computed from the lists on every access, so
    they always reflect the lists' current lengths.
    """

    # Emotion label names.
    emotion: List[str]
    # Topic label names.
    topic: List[str]

    @property
    def num_emotions(self) -> int:
        """Count of entries in ``emotion``."""
        count = len(self.emotion)
        return count

    @property
    def num_topics(self) -> int:
        """Count of entries in ``topic``."""
        count = len(self.topic)
        return count
def load_labels(path: str | Path) -> LabelMetadata:
    """Load emotion/topic label vocabularies from a JSON file.

    The file must hold a JSON object with an ``emotion`` (or ``emotions``) entry
    and a ``topic`` (or ``topics``) entry. Each must be a non-empty list of
    strings.

    Args:
        path: Location of the labels JSON file.

    Returns:
        LabelMetadata built from copies of the two lists.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), the top-level value is not an object,
            either list is missing or empty, or a list holds non-string entries.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Labels not found: {path}")
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    # Check before calling .get(): a top-level JSON array would otherwise
    # surface as an AttributeError instead of a clear ValueError.
    if not isinstance(data, dict):
        raise ValueError(f"Labels file '{path}' must contain a JSON object")
    # Accept singular and plural spellings. An empty/falsy singular entry
    # falls back to the plural one.
    emotion = data.get("emotion") or data.get("emotions", [])
    topic = data.get("topic") or data.get("topics", [])
    if not emotion or not topic:
        raise ValueError("Labels file must contain 'emotion' and 'topic' lists")
    # A bare string is truthy and would otherwise be used as a list of characters.
    for key, values in (("emotion", emotion), ("topic", topic)):
        if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
            raise ValueError(f"Labels file '{path}': '{key}' must be a list of strings")
    return LabelMetadata(emotion=list(emotion), topic=list(topic))
def save_labels(labels: LabelMetadata, path: str | Path) -> None:
    """Serialize ``labels`` to ``path`` as indented JSON.

    The output object has ``emotion`` and ``topic`` keys holding the two label
    lists. Missing parent directories are created before writing.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = {"emotion": labels.emotion, "topic": labels.topic}
    with target.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
# --------------- Reproducibility ---------------
def set_seed(seed: int) -> None:
    """Seed every random number generator this module relies on.

    Python's ``random``, NumPy's global RNG and PyTorch's generator are seeded
    in that order. All CUDA devices are also seeded, but only when CUDA is
    available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# --------------- Config Loading ---------------
@dataclass
class Config:
    """Minimal dataclass that carries a parsed configuration mapping."""

    # Top-level mapping, e.g. the dict that load_yaml parses from a YAML file.
    data: dict
def load_yaml(path: str | Path) -> Config:
    """Parse the YAML file at ``path`` and wrap its top-level mapping in a Config.

    PyYAML is imported inside the function, so it is only required when this
    function is called. Parsing uses ``yaml.safe_load``.

    Raises:
        ValueError: If the parsed document is not a mapping. This includes an
            empty file, which ``safe_load`` returns as ``None``.
    """
    import yaml

    with Path(path).open("r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if isinstance(parsed, dict):
        return Config(data=parsed)
    raise ValueError(f"YAML '{path}' must contain a mapping")
|