"""Speaker classifier scaffold for multi-task training and evaluation. This module provides a small PyTorch `SpeakerClassifier` that maps embeddings (or pooled encoder outputs) to speaker logits, plus helpers to build speaker mappings from manifests. """ import json from pathlib import Path from typing import Dict, List try: import torch import torch.nn as nn except Exception: torch = None nn = None class SpeakerClassifier: """A light-weight wrapper that exposes an API-compatible classifier. If PyTorch is available, `SpeakerClassifier.model` is a `nn.Module`. Otherwise this is a placeholder to keep the dependency optional in tests. """ def __init__(self, input_dim: int, num_speakers: int, dropout: float = 0.1): self.input_dim = input_dim self.num_speakers = num_speakers self.dropout = dropout if torch is not None and nn is not None: self.model = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(input_dim, num_speakers), ) else: self.model = None def forward(self, x): if self.model is None: raise RuntimeError("PyTorch not available for SpeakerClassifier") return self.model(x) def build_speaker_map(manifest_paths: List[str]) -> Dict[str, int]: """Read JSONL manifest(s) and return a speaker->id mapping. The manifest format: each line is JSON with optional "speaker" key. Labels are returned in deterministic sorted order. """ speakers = set() for p in manifest_paths: pth = Path(p) if not pth.exists(): continue with open(pth, "r", encoding="utf-8") as fh: for line in fh: line = line.strip() if not line: continue try: obj = json.loads(line) except Exception: continue spk = obj.get("speaker") if spk is not None: speakers.add(str(spk)) sorted_spks = sorted(speakers) return {s: i for i, s in enumerate(sorted_spks)}