File size: 2,311 Bytes
9f76952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73


import os
import torch
import torchaudio
import torch.serialization
from pyannote.core import Segment, Timeline

# --- 1. MONKEY PATCH (works around PyTorch 2.6 making torch.load default to weights_only=True) ---
# Keep a handle on the unpatched loader so the wrapper can delegate to it.
original_load = torch.serialization.load


def forced_load(f, map_location=None, pickle_module=None, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overridden; every
    other argument is forwarded unchanged.

    NOTE(review): ``weights_only=False`` allows full unpickling of the
    checkpoint, so only load files from sources you trust.
    """
    forwarded = {**kwargs, 'weights_only': False}
    return original_load(
        f,
        map_location=map_location,
        pickle_module=pickle_module,
        **forwarded,
    )


# Install the wrapper under both public entry points.
torch.load = torch.serialization.load = forced_load
# ---------------------------------------------------------

# Model: neural network; Segmentation: training logic; get_protocol: dataset loader; pl.Trainer: training engine

from pyannote.audio import Model
from pyannote.audio.tasks import Segmentation
from pyannote.database import get_protocol, FileFinder
import pytorch_lightning as pl

# Point pyannote.database at the dataset configuration file (path is relative
# to the current working directory).
# NOTE(review): this runs after `pyannote.database` has already been imported
# above; confirm the library reads the variable lazily, otherwise this
# assignment needs to happen before that import.
os.environ.update(PYANNOTE_DATABASE_CONFIG="database.yml")

def train_segmentation():
    """Fine-tune ``pyannote/segmentation-3.0`` on the Hindi-Bhojpuri protocol.

    Loads the ``HindiBhojpuri.SpeakerDiarization.Segmentation`` protocol,
    wraps it in a ``Segmentation`` task (2-second chunks, batch size 4),
    attaches that task to the pretrained model and runs a CPU
    ``pl.Trainer`` for 5 epochs, writing results under ``training_results``.

    Returns:
        None.

    Raises:
        RuntimeError: if the pretrained model could not be loaded.
    """
    # 2. PREPROCESSORS
    def get_annotated(file):
        """Return a Timeline covering the whole audio file of ``file``."""
        info = torchaudio.info(file["audio"])
        # Calculate duration: total frames / sample rate
        duration = info.num_frames / info.sample_rate
        # Return the 'Timeline' object the library is looking for
        return Timeline([Segment(0, duration)])

    preprocessors = {
        "audio": FileFinder(),
        "annotated": get_annotated,
    }

    # 3. LOAD PROTOCOL
    print("Loading Hindi-Bhojpuri Protocol...")
    protocol = get_protocol(
        'HindiBhojpuri.SpeakerDiarization.Segmentation',
        preprocessors=preprocessors
    )

    # 4. SETUP TASK
    seg_task = Segmentation(
        protocol,
        duration=2.0,
        batch_size=4,
        num_workers=0
    )

    # 5. LOAD MODEL - Start from the pretrained segmentation model (described
    # in the original notes as English-trained) and adapt it to Hindi/Bhojpuri.
    # This is transfer learning, not training from scratch.
    print("Attempting to load model...")
    model = Model.from_pretrained("pyannote/segmentation-3.0")
    # Model.from_pretrained may return None instead of raising when the
    # checkpoint cannot be downloaded (e.g. gated model without
    # authentication). Fail here with a clear message rather than with an
    # AttributeError on the next line.
    if model is None:
        raise RuntimeError(
            "Could not load 'pyannote/segmentation-3.0'. Make sure you have "
            "accepted the model's user conditions on Hugging Face and are "
            "authenticated."
        )
    model.task = seg_task

    # 6. TRAINER
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=5,
        default_root_dir="training_results"
    )

    # 7. START
    print("--- Starting Fine-tuning ---")
    trainer.fit(model)

# Script entry point: start fine-tuning only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train_segmentation()