File size: 2,365 Bytes
bfc6d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""A1-Max MuQ inference - MuQ embedding extraction and prediction."""

import numpy as np
import torch

from constants import MODEL_CONFIG
from models.loader import ModelCache


@torch.no_grad()
def extract_muq_embeddings(
    audio: np.ndarray,
    cache: ModelCache,
    layer_start: int = None,
    layer_end: int = None,
    max_frames: int = None,
) -> torch.Tensor:
    """Extract MuQ embeddings from an audio waveform.

    Runs the cached MuQ model on the waveform, takes the hidden states in
    the half-open layer range ``[layer_start, layer_end)``, and averages
    them element-wise into a single frame-level embedding sequence.

    Args:
        audio: Audio waveform as a NumPy array. A batch dimension is added
            in front before the model is called. (The original notes state
            the model expects 24kHz audio; that is not checked here.)
        cache: Model cache providing ``muq_model`` and ``device``.
        layer_start: Index of the first hidden state to average (inclusive).
            ``None`` selects ``MODEL_CONFIG["muq_layer_start"]``. An explicit
            ``0`` is honoured.
        layer_end: Index one past the last hidden state to average
            (exclusive). ``None`` selects ``MODEL_CONFIG["muq_layer_end"]``.
        max_frames: Maximum number of leading frames to keep. ``None``
            selects ``MODEL_CONFIG["max_frames"]``.

    Returns:
        Embeddings tensor of shape [T, D], where T is the number of frames
        after truncation and D is the model's hidden size (1024 according
        to the original documentation).
    """
    # Compare against None explicitly: with `value or default`, a caller
    # passing 0 (e.g. layer_start=0 for the first hidden state) would have
    # the argument silently replaced by the configured default.
    if layer_start is None:
        layer_start = MODEL_CONFIG["muq_layer_start"]
    if layer_end is None:
        layer_end = MODEL_CONFIG["muq_layer_end"]
    if max_frames is None:
        max_frames = MODEL_CONFIG["max_frames"]

    # MuQ expects a [B, samples] tensor, so add a leading batch dimension.
    wavs = torch.tensor(audio).unsqueeze(0).to(cache.device)

    # Ask the model for the hidden states of every layer.
    outputs = cache.muq_model(wavs, output_hidden_states=True)

    # hidden_states is a tuple of [B, T, D] tensors; average the selected
    # layers element-wise, then drop the batch dimension.
    selected_layers = outputs.hidden_states[layer_start:layer_end]
    embeddings = torch.stack(selected_layers, dim=0).mean(dim=0).squeeze(0)

    # Keep at most max_frames leading frames; slicing past the end of the
    # tensor is a no-op, so no separate length check is needed.
    return embeddings[:max_frames]


@torch.no_grad()
def predict_with_ensemble(
    embeddings: torch.Tensor,
    cache: ModelCache,
) -> np.ndarray:
    """Average the predictions of every A1-Max head held in the cache.

    Each head in ``cache.muq_heads`` is called on the same frame-level
    embeddings; its output is moved to the CPU, converted to a NumPy array,
    and the per-head arrays are averaged along the ensemble axis. (The
    original notes describe a 4-fold ensemble whose heads apply attention
    pooling followed by an encoder and a regression head producing 6
    scores; those details live in the heads, not in this function.)

    Args:
        embeddings: Frame embeddings [T, D] from MuQ.
        cache: Model cache whose ``muq_heads`` holds the loaded heads.

    Returns:
        Element-wise mean of the head predictions.

    Raises:
        RuntimeError: If the cache holds no heads.
    """
    fold_heads = cache.muq_heads
    if not fold_heads:
        raise RuntimeError("No A1-Max heads loaded in cache")

    # One NumPy prediction per fold head, in cache order.
    fold_outputs = [
        fold_head(embeddings).cpu().numpy() for fold_head in fold_heads
    ]

    # Axis 0 is the ensemble axis introduced by the list of predictions.
    return np.mean(fold_outputs, axis=0)