File size: 1,931 Bytes
72b2f6d 33efa44 72b2f6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# Model configuration and calibration dataclasses
# Re-exported from the main package for demo use
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@dataclass
class ModelConfig:
    """Configuration for a trained early exit model.

    Instances are normally built from a JSON file with :meth:`from_json`.
    """

    model_name: str
    num_heads: int
    head_layer_indices: List[int]
    quantization: str  # expected values: "none", "4bit", "8bit"
    hidden_size: int
    vocab_size: int
    num_hidden_layers: int
    # Optional free-form dict; absent from the JSON file -> None.
    training_config: Optional[Dict] = None

    @classmethod
    def from_json(cls, path: str) -> "ModelConfig":
        """Load a ``ModelConfig`` from the JSON file at *path*.

        Every field except ``training_config`` must be present in the JSON
        object; any other keys in the file are ignored.

        Args:
            path: Filesystem path to a UTF-8 encoded JSON file.

        Returns:
            The parsed ``ModelConfig``.

        Raises:
            KeyError: if a required key is missing from the JSON object.
        """
        # JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying
        # on the platform's locale default, which is not UTF-8 everywhere.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(
            model_name=data["model_name"],
            num_heads=data["num_heads"],
            head_layer_indices=data["head_layer_indices"],
            quantization=data["quantization"],
            hidden_size=data["hidden_size"],
            vocab_size=data["vocab_size"],
            num_hidden_layers=data["num_hidden_layers"],
            training_config=data.get("training_config"),
        )
@dataclass
class CalibrationResult:
    """Calibration results with thresholds per head per accuracy level.

    ``thresholds`` maps an accuracy-level key formatted with two decimals
    (e.g. ``"0.95"``) to a mapping of head-index string (JSON object keys
    are always strings) to threshold value.
    """

    model_config_path: str
    calibration_dataset: str
    calibration_samples: int
    uncertainty_metric: str  # expected values: "entropy" or "confidence"
    accuracy_levels: List[float]
    thresholds: Dict[str, Dict[str, float]] = field(default_factory=dict)
    statistics: Dict[str, Dict] = field(default_factory=dict)

    @classmethod
    def from_json(cls, path: str) -> "CalibrationResult":
        """Load a ``CalibrationResult`` from the JSON file at *path*.

        Keys that do not correspond to a field of this class are ignored,
        so extra metadata in the file does not break loading.

        Args:
            path: Filesystem path to a UTF-8 encoded JSON file.

        Returns:
            The parsed ``CalibrationResult``.

        Raises:
            TypeError: if a required field is missing from the JSON object.
        """
        import dataclasses

        # JSON text is UTF-8 (RFC 8259); don't depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Passing unknown keys straight to __init__ would raise TypeError.
        accepted = {fld.name for fld in dataclasses.fields(cls) if fld.init}
        return cls(**{key: value for key, value in data.items() if key in accepted})

    def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
        """Get all thresholds for a given accuracy level.

        Args:
            accuracy_level: Target accuracy; it is formatted with two
                decimals (``f"{level:.2f}"``) to find the stored key.

        Returns:
            Mapping from head index (converted to ``int``) to threshold.

        Raises:
            KeyError: if no thresholds are stored for that level; the
                message lists the levels that are available.
        """
        level_key = f"{accuracy_level:.2f}"
        try:
            per_head = self.thresholds[level_key]
        except KeyError:
            available = ", ".join(sorted(self.thresholds)) or "none"
            raise KeyError(
                f"no thresholds for accuracy level {level_key!r}; "
                f"available levels: {available}"
            ) from None
        return {int(head): value for head, value in per_head.items()}
|