File size: 2,245 Bytes
fda93d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Speaker classifier scaffold for multi-task training and evaluation.

This module provides a small PyTorch `SpeakerClassifier` that maps embeddings
(or pooled encoder outputs) to speaker logits, plus helpers for building
speaker-to-id mappings from JSONL manifests.
"""

import json
from pathlib import Path
from typing import Dict, Iterable, List, Union

try:
    import torch
    import torch.nn as nn
except Exception:
    torch = None
    nn = None


class SpeakerClassifier:
    """Light-weight speaker classifier that keeps PyTorch an optional dependency.

    When PyTorch imports successfully, ``self.model`` is an ``nn.Sequential``
    of ``nn.Dropout(p=dropout)`` followed by ``nn.Linear(input_dim,
    num_speakers)``. ``forward`` therefore maps a tensor whose last dimension
    is ``input_dim`` to logits whose last dimension is ``num_speakers``.

    When PyTorch is unavailable, ``self.model`` is ``None``. The object can
    still be constructed (e.g. in tests), but ``forward`` raises.

    Instances are callable (``clf(x)``), mirroring ``nn.Module`` call syntax.

    Args:
        input_dim: Size of the last dimension of the input embeddings.
        num_speakers: Number of output logits (one per speaker).
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, input_dim: int, num_speakers: int, dropout: float = 0.1):
        self.input_dim = input_dim
        self.num_speakers = num_speakers
        self.dropout = dropout
        if torch is not None and nn is not None:
            # NOTE: like any nn.Module, this starts in training mode, so the
            # dropout layer is active until ``self.model.eval()`` is called.
            self.model = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(input_dim, num_speakers),
            )
        else:
            self.model = None

    def forward(self, x):
        """Return speaker logits for ``x``.

        Raises:
            RuntimeError: If PyTorch was not importable, so no model exists.
        """
        if self.model is None:
            raise RuntimeError("PyTorch not available for SpeakerClassifier")
        return self.model(x)

    def __call__(self, x):
        """Alias for :meth:`forward`, so instances can be called like a module."""
        return self.forward(x)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(input_dim={self.input_dim}, "
            f"num_speakers={self.num_speakers}, dropout={self.dropout})"
        )


def build_speaker_map(
    manifest_paths: Union[str, Path, Iterable[Union[str, Path]]],
) -> Dict[str, int]:
    """Read JSONL manifest(s) and return a speaker -> integer id mapping.

    Each non-blank manifest line is expected to be a JSON object with an
    optional ``"speaker"`` key. Speaker values are converted with ``str()``,
    so ``2`` and ``"2"`` share one label. Ids are assigned in sorted label
    order, so the result is deterministic for a given set of speakers.

    Reading is best-effort. The following are all skipped:

    * paths that are not existing regular files;
    * blank lines and lines that are not valid JSON;
    * lines that decode to something other than a JSON object;
    * objects whose ``"speaker"`` is missing or ``null``.

    Args:
        manifest_paths: A single manifest path, or an iterable of paths.

    Returns:
        Mapping from speaker label to an id in ``range(len(mapping))``.
        It is empty if no speakers were found.
    """
    # A bare str would otherwise be iterated character by character (and a
    # Path is not iterable at all), so treat a single path as a one-item list.
    if isinstance(manifest_paths, (str, Path)):
        manifest_paths = [manifest_paths]

    speakers = set()
    for p in manifest_paths:
        pth = Path(p)
        # is_file() rather than exists(): a directory must be skipped instead
        # of crashing open() with IsADirectoryError.
        if not pth.is_file():
            continue
        # utf-8-sig strips a leading BOM. With plain utf-8 the BOM would make
        # the first line invalid JSON and silently drop its speaker.
        with open(pth, "r", encoding="utf-8-sig") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    continue
                # Valid JSON that is not an object (list, number, string) has
                # no .get(); skip it rather than raising AttributeError.
                if not isinstance(obj, dict):
                    continue
                spk = obj.get("speaker")
                if spk is not None:
                    speakers.add(str(spk))
    return {s: i for i, s in enumerate(sorted(speakers))}