File size: 3,898 Bytes
a361db3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Data structures for pipeline stages."""

from dataclasses import dataclass, asdict, field
from typing import Optional, List
import json


@dataclass
class TalkerInfo:
    """Per-speaker record produced by the pipeline.

    Only ``id`` and ``label`` are required; every other attribute defaults to
    ``None`` (or ``False`` for ``is_toi``). ``to_dict`` omits attributes that
    are still ``None`` so the serialized form stays compact.
    """
    id: int
    label: str                          # speaker label, e.g. "SPEAKER_00"
    gender: Optional[str] = None        # "male" / "female" / "unknown" / "ambiguous"
    mean_f0_hz: Optional[float] = None  # fundamental frequency, in Hz
    transcript: Optional[str] = None    # recognized speech text
    language: Optional[str] = None      # detected language code, e.g. "en", "da"
    wav_path: Optional[str] = None      # filesystem path of the extracted source WAV
    is_toi: bool = False                # True when this speaker is the talker of interest
    toi_reason: Optional[str] = None    # explanation for the talker-of-interest choice
    direction_deg: Optional[float] = None  # direction of arrival, in degrees
    energy: Optional[float] = None      # energy level
    selection_score: Optional[float] = None  # numeric talker-of-interest score

    def to_dict(self):
        """Return the fields as a plain dict, leaving out any whose value is None.

        Falsy-but-set values such as ``0.0``, ``""`` or ``False`` are kept;
        only ``None`` is filtered. Keys appear in field-declaration order.
        """
        result = {}
        for name, value in asdict(self).items():
            if value is None:
                continue
            result[name] = value
        return result


@dataclass
class PipelineOutput:
    """Complete output from pipeline execution.

    Serialize with ``to_dict`` / ``to_json`` and rebuild with ``from_dict``.
    In the serialized form, the four ``*_method`` / ``asr_model`` attributes
    are nested under a ``"processing_methods"`` key, and ``processing_notes``
    is stored under ``"notes"``.
    """
    input_file: str
    approach: str                      # "ica" / "ica_deeplearning" / "frankenstein"
    duration_seconds: float
    sample_rate: int
    n_speakers: int
    talker_of_interest: int            # Source index (1-indexed)
    # Quoted so the annotation is not evaluated when the class is created.
    sources: List["TalkerInfo"] = field(default_factory=list)

    # Performance metrics
    execution_time_seconds: Optional[float] = None
    separation_method: Optional[str] = None
    doa_method: Optional[str] = None
    gender_method: Optional[str] = None
    asr_model: Optional[str] = None

    # Optional: Processing chain details
    processing_notes: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        ``duration_seconds`` and ``execution_time_seconds`` are rounded to two
        decimals. Unset optional values are emitted as ``None`` rather than
        omitted. Each source is serialized with its own ``to_dict``.
        """
        # Compare against None explicitly: a measured time of 0.0 is a real
        # value and must not be turned into None by a truthiness check.
        execution_time = (
            round(self.execution_time_seconds, 2)
            if self.execution_time_seconds is not None
            else None
        )
        return {
            "input_file": self.input_file,
            "approach": self.approach,
            "duration_seconds": round(self.duration_seconds, 2),
            "sample_rate": self.sample_rate,
            "n_speakers": self.n_speakers,
            "talker_of_interest": self.talker_of_interest,
            "execution_time_seconds": execution_time,
            "processing_methods": {
                "separation": self.separation_method,
                "direction_of_arrival": self.doa_method,
                "gender_classification": self.gender_method,
                "asr_model": self.asr_model,
            },
            "sources": [s.to_dict() for s in self.sources],
            "notes": self.processing_notes,
        }

    def to_json(self, indent: int = 2) -> str:
        """Return ``to_dict()`` encoded as a JSON string with the given indent."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_dict(cls, data: dict) -> "PipelineOutput":
        """Reconstruct an instance from a dictionary shaped like ``to_dict()`` output.

        Args:
            data: Mapping that must contain ``input_file``, ``approach``,
                ``duration_seconds``, ``sample_rate``, ``n_speakers`` and
                ``talker_of_interest``. ``sources``, ``execution_time_seconds``,
                ``processing_methods`` and ``notes`` are optional; a missing or
                null ``processing_methods`` leaves the method fields as None.

        Raises:
            KeyError: If a required key is missing.
            TypeError: If a source entry contains keys that ``TalkerInfo``
                does not accept.
        """
        sources = [TalkerInfo(**s) for s in data.get("sources", [])]
        # `or {}` also covers an explicit null, not just a missing key.
        methods = data.get("processing_methods") or {}
        return cls(
            input_file=data["input_file"],
            approach=data["approach"],
            duration_seconds=data["duration_seconds"],
            sample_rate=data["sample_rate"],
            n_speakers=data["n_speakers"],
            talker_of_interest=data["talker_of_interest"],
            sources=sources,
            execution_time_seconds=data.get("execution_time_seconds"),
            separation_method=methods.get("separation"),
            doa_method=methods.get("direction_of_arrival"),
            gender_method=methods.get("gender_classification"),
            asr_model=methods.get("asr_model"),
            processing_notes=data.get("notes"),
        )