"""Model loading and caching for A1-Max MuQ LoRA inference."""
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn

from constants import MODEL_CONFIG, N_FOLDS


class A1MaxInferenceHead(nn.Module):
    """Inference-only version of MuQLoRAMaxModel's predict_scores path.

    Replicates the architecture needed for score prediction:
    - Attention pooling: [B, T, D] -> [B, D]
    - Encoder: 2-layer MLP [B, D] -> [B, hidden_dim]
    - Regression head: MLP + sigmoid [B, hidden_dim] -> [B, num_labels]

    Does NOT include ranking/contrastive/comparator modules (training-only).
    """

    def __init__(
        self,
        input_dim: int = 1024,
        hidden_dim: int = 512,
        num_labels: int = 6,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_labels = num_labels

        # Attention pooling (matches MuQLoRAModel.attn)
        self.attn = nn.Sequential(
            nn.Linear(input_dim, 256), nn.Tanh(), nn.Linear(256, 1)
        )

        # Shared encoder (matches MuQLoRAModel.encoder)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Regression head (matches MuQLoRAModel.regression_head)
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_labels),
            nn.Sigmoid(),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Predict quality scores from frame embeddings.

        Args:
            embeddings: Frame embeddings [B, T, D] or [T, D].

        Returns:
            Scores [B, num_labels] or [num_labels] in [0, 1].
        """
        squeeze_output = False
        if embeddings.dim() == 2:
            embeddings = embeddings.unsqueeze(0)
            squeeze_output = True

        # Attention pool
        scores = self.attn(embeddings).squeeze(-1)       # [B, T]
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)  # [B, T, 1]
        pooled = (embeddings * w).sum(1)                 # [B, D]

        # Encode
        z = self.encoder(pooled)                         # [B, hidden_dim]

        # Predict
        result = self.regression_head(z)                 # [B, num_labels]
        return result.squeeze(0) if squeeze_output else result
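
# Shape sketch (illustrative only; random tensors stand in for real MuQ frame
# embeddings). With the defaults above, the head maps [B, T, 1024] frame
# embeddings to [B, 6] scores in [0, 1], or [T, 1024] to [6] for a single clip:
#
#   head = A1MaxInferenceHead().eval()
#   with torch.no_grad():
#       emb = torch.randn(2, 300, 1024)  # 2 clips, 300 frames, MuQ-large dim
#       scores = head(emb)               # torch.Size([2, 6])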


class ModelCache:
    """Singleton cache for loaded models."""

    _instance: Optional["ModelCache"] = None

    def __new__(cls) -> "ModelCache":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self.muq_model = None
        self.muq_heads: List[A1MaxInferenceHead] = []
        self.device = None
        self._initialized = True
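
    # Illustrative note on the singleton: construction always returns the same
    # object and __init__ runs its body only once, so callers can instantiate
    # freely without re-triggering model loads:
    #
    #   assert ModelCache() is ModelCache()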

    def initialize(self, device: str = "cuda", checkpoint_dir: Optional[Path] = None):
        """Load MuQ model and A1-Max prediction heads. Called once on container start."""
        if self.muq_model is not None:
            return

        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Initializing A1-Max models on {self.device}...")

        # Load MuQ from HuggingFace
        print("Loading MuQ-large-msd-iter...")
        try:
            from muq import MuQ

            self.muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
            self.muq_model = self.muq_model.to(self.device)
            self.muq_model.eval()
            print("MuQ loaded successfully")
        except ImportError as e:
            raise ImportError(
                "MuQ library not found. Install with: pip install muq"
            ) from e

        # Load A1-Max prediction heads (4 folds)
        print("Loading A1-Max prediction heads...")
        checkpoint_dir = checkpoint_dir or Path("/repository/checkpoints")
        if not checkpoint_dir.exists():
            checkpoint_dir = Path("/app/checkpoints")

        for fold in range(N_FOLDS):
            ckpt_path = checkpoint_dir / f"fold_{fold}" / "best.ckpt"
            # Also try the epoch-based naming from the sweep
            if not ckpt_path.exists():
                fold_dir = checkpoint_dir / f"fold_{fold}"
                if fold_dir.exists():
                    ckpts = sorted(fold_dir.glob("*.ckpt"))
                    if ckpts:
                        ckpt_path = ckpts[0]
            if ckpt_path.exists():
                head = self._load_a1max_head(ckpt_path)
                self.muq_heads.append(head)
                print(f" Loaded fold {fold} from {ckpt_path}")
            else:
                print(f" Warning: No checkpoint found for fold {fold}")

        print(f"Initialization complete. {len(self.muq_heads)} heads loaded.")

    def _load_a1max_head(self, ckpt_path: Path) -> A1MaxInferenceHead:
        """Load an A1MaxInferenceHead from a PyTorch Lightning checkpoint."""
        checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        hparams = checkpoint.get("hyper_parameters", {})
        head = A1MaxInferenceHead(
            input_dim=hparams.get("input_dim", MODEL_CONFIG["input_dim"]),
            hidden_dim=hparams.get("hidden_dim", MODEL_CONFIG["hidden_dim"]),
            num_labels=hparams.get("num_labels", MODEL_CONFIG["num_labels"]),
            dropout=hparams.get("dropout", MODEL_CONFIG["dropout"]),
        )

        # Load the state dict from the Lightning checkpoint. Lightning saves
        # keys as attn.0.weight, encoder.0.weight, regression_head.0.weight,
        # etc.; keep only the submodules this inference head replicates.
        state_dict = checkpoint["state_dict"]
        head_state = {}
        for key, value in state_dict.items():
            if key.startswith(("attn.", "encoder.", "regression_head.")):
                head_state[key] = value
        head.load_state_dict(head_state, strict=True)
        head.to(self.device)
        head.eval()
        return head


_cache: Optional[ModelCache] = None


def get_model_cache() -> ModelCache:
    """Get the global model cache instance."""
    global _cache
    if _cache is None:
        _cache = ModelCache()
    return _cache
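
# Typical container-startup usage (hypothetical entrypoint; the checkpoint
# layout <dir>/fold_<k>/best.ckpt mirrors what initialize() probes for). Since
# ModelCache.__new__ already enforces a single instance, get_model_cache()
# mainly gives callers a stable import point:
#
#   cache = get_model_cache()
#   cache.initialize(device="cuda", checkpoint_dir=Path("/repository/checkpoints"))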