"""Model loading and caching for A1-Max MuQ LoRA inference."""

from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn

from constants import MODEL_CONFIG, N_FOLDS


class A1MaxInferenceHead(nn.Module):
    """Inference-only version of MuQLoRAMaxModel's predict_scores path.

    Replicates the architecture needed for score prediction:
    - Attention pooling: [B, T, D] -> [B, D]
    - Encoder: 2-layer MLP [B, D] -> [B, hidden_dim]
    - Regression head: MLP + sigmoid [B, hidden_dim] -> [B, num_labels]

    Does NOT include ranking/contrastive/comparator modules (training-only).
    """

    def __init__(
        self,
        input_dim: int = 1024,
        hidden_dim: int = 512,
        num_labels: int = 6,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.num_labels = num_labels

        # Attention pooling (matches MuQLoRAModel.attn)
        self.attn = nn.Sequential(
            nn.Linear(input_dim, 256), nn.Tanh(), nn.Linear(256, 1)
        )

        # Shared encoder (matches MuQLoRAModel.encoder)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Regression head (matches MuQLoRAModel.regression_head)
        self.regression_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_labels),
            nn.Sigmoid(),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Predict quality scores from frame embeddings.

        Args:
            embeddings: Frame embeddings [B, T, D] or [T, D].

        Returns:
            Scores [B, num_labels] or [num_labels] in [0, 1].
        """
        squeeze_output = False
        if embeddings.dim() == 2:
            embeddings = embeddings.unsqueeze(0)
            squeeze_output = True

        # Attention pool
        scores = self.attn(embeddings).squeeze(-1)  # [B, T]
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)  # [B, T, 1]
        pooled = (embeddings * w).sum(1)  # [B, D]

        # Encode
        z = self.encoder(pooled)  # [B, hidden_dim]

        # Predict
        result = self.regression_head(z)  # [B, num_labels]

        return result.squeeze(0) if squeeze_output else result

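# Example (sketch, not executed on import): running the head on dummy frame
# embeddings. The random tensor stands in for MuQ frame embeddings, which in
# the real pipeline come from the MuQ backbone loaded by ModelCache below.
#
#     head = A1MaxInferenceHead(input_dim=1024, hidden_dim=512, num_labels=6)
#     head.eval()
#     with torch.no_grad():
#         scores = head(torch.randn(250, 1024))  # [T, D] -> [num_labels]
#     # scores has shape (6,) with values in [0, 1]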

class ModelCache:
    """Singleton cache for loaded models."""

    _instance: Optional["ModelCache"] = None

    def __new__(cls) -> "ModelCache":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self.muq_model = None
        self.muq_heads: List[A1MaxInferenceHead] = []
        self.device = None
        self._initialized = True

    def initialize(self, device: str = "cuda", checkpoint_dir: Optional[Path] = None):
        """Load MuQ model and A1-Max prediction heads. Called once on container start."""
        if self.muq_model is not None:
            return

        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Initializing A1-Max models on {self.device}...")

        # Load MuQ from HuggingFace
        print("Loading MuQ-large-msd-iter...")
        try:
            from muq import MuQ
            self.muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
            self.muq_model = self.muq_model.to(self.device)
            self.muq_model.eval()
            print("MuQ loaded successfully")
        except ImportError as e:
            raise ImportError(
                "MuQ library not found. Install with: pip install muq"
            ) from e

        # Load A1-Max prediction heads (4 folds)
        print("Loading A1-Max prediction heads...")
        checkpoint_dir = checkpoint_dir or Path("/repository/checkpoints")
        if not checkpoint_dir.exists():
            checkpoint_dir = Path("/app/checkpoints")

        for fold in range(N_FOLDS):
            ckpt_path = checkpoint_dir / f"fold_{fold}" / "best.ckpt"
            # Fall back to any checkpoint in the fold directory
            # (e.g. epoch-named files produced by the sweep)
            if not ckpt_path.exists():
                fold_dir = checkpoint_dir / f"fold_{fold}"
                if fold_dir.exists():
                    ckpts = sorted(fold_dir.glob("*.ckpt"))
                    if ckpts:
                        ckpt_path = ckpts[0]
            if ckpt_path.exists():
                head = self._load_a1max_head(ckpt_path)
                self.muq_heads.append(head)
                print(f"  Loaded fold {fold} from {ckpt_path}")
            else:
                print(f"  Warning: No checkpoint found for fold {fold}")

        print(f"Initialization complete. {len(self.muq_heads)} heads loaded.")

    def _load_a1max_head(self, ckpt_path: Path) -> A1MaxInferenceHead:
        """Load an A1MaxInferenceHead from PyTorch Lightning checkpoint."""
        checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=False)

        hparams = checkpoint.get("hyper_parameters", {})

        head = A1MaxInferenceHead(
            input_dim=hparams.get("input_dim", MODEL_CONFIG["input_dim"]),
            hidden_dim=hparams.get("hidden_dim", MODEL_CONFIG["hidden_dim"]),
            num_labels=hparams.get("num_labels", MODEL_CONFIG["num_labels"]),
            dropout=hparams.get("dropout", MODEL_CONFIG["dropout"]),
        )

        # Load state dict from Lightning checkpoint
        state_dict = checkpoint["state_dict"]

        # Map Lightning keys to inference head keys
        # Lightning saves as: attn.0.weight, encoder.0.weight, regression_head.0.weight, etc.
        head_state = {}
        for key, value in state_dict.items():
            if key.startswith("attn.") or key.startswith("encoder.") or key.startswith("regression_head."):
                head_state[key] = value

        head.load_state_dict(head_state, strict=True)

        head.to(self.device)
        head.eval()
        return head


_cache: Optional[ModelCache] = None


def get_model_cache() -> ModelCache:
    """Get the global model cache instance."""
    global _cache
    if _cache is None:
        _cache = ModelCache()
    return _cache
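

# ---------------------------------------------------------------------------
# Minimal smoke test (sketch). Runs only when the module is executed as a
# script. The random tensor stands in for MuQ frame embeddings, which in
# production come from ModelCache.muq_model; averaging scores across fold
# heads is shown as one plausible ensembling strategy, not something this
# module prescribes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    head = A1MaxInferenceHead(input_dim=1024, hidden_dim=512, num_labels=6)
    head.eval()

    embeddings = torch.randn(2, 250, 1024)  # [B, T, D] dummy frame embeddings
    with torch.no_grad():
        scores = head(embeddings)  # [B, num_labels], values in [0, 1]
    print("single-head scores:", tuple(scores.shape), scores.min().item(), scores.max().item())

    # Ensemble sketch (assumes checkpoints are available on disk):
    # cache = get_model_cache()
    # cache.initialize(device="cuda", checkpoint_dir=Path("checkpoints"))
    # with torch.no_grad():
    #     fold_scores = torch.stack(
    #         [h(embeddings.to(cache.device)) for h in cache.muq_heads]
    #     )                                   # [n_folds, B, num_labels]
    #     mean_scores = fold_scores.mean(dim=0)  # [B, num_labels]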