Capstone04 committed
Commit e7d2bdc · verified · 1 Parent(s): cc204c2

Upload folder using huggingface_hub

Files changed (1):
  asr_diarization/pipeline.py  +25 -76
asr_diarization/pipeline.py CHANGED

@@ -48,47 +48,9 @@ class ASR_Diarization:
             for t, _, spk in diarization.itertracks(yield_label=True)
         ]
 
-    # def run_transcription(self, audio_path, diar_json):
-    #     audio, sr = torchaudio.load(audio_path)
-    #     merged_segments = []
-    #     speaker_segments = {}
-
-    #     for seg in diar_json:
-    #         segment_start, segment_end, spk = seg["segment_start"], seg["segment_end"], seg["speaker"]
-    #         start_sample, end_sample = int(segment_start * sr), int(segment_end * sr)
-    #         chunk = audio[0, start_sample:end_sample].numpy()
-
-    #         reduced = nr.reduce_noise(y=chunk, sr=sr)
-    #         result = self.asr_pipeline(reduced)
-
-    #         tokens = []
-    #         if "chunks" in result:
-    #             for word_info in result["chunks"]:
-    #                 start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
-    #                 tokens.append({
-    #                     "tag": "w",
-    #                     "start": start_ts,
-    #                     "end": end_ts,
-    #                     "text": word_info["text"]
-    #                 })
-
-    #         seg_dict = {
-    #             "speaker": spk,
-    #             "segment_start": segment_start,
-    #             "segment_end": segment_end,
-    #             "tokens": tokens
-    #         }
-    #         merged_segments.append(seg_dict)
-
-    #         if spk not in speaker_segments:
-    #             speaker_segments[spk] = []
-    #         speaker_segments[spk].append(seg_dict)
-
-    #     return merged_segments, list(speaker_segments.keys())
-
     def run_transcription(self, audio_path, diar_json):
         audio, sr = torchaudio.load(audio_path)
-        all_word_segments = []
+        merged_segments = []
         speaker_segments = {}
 
         for seg in diar_json:
@@ -99,45 +61,32 @@ class ASR_Diarization:
             reduced = nr.reduce_noise(y=chunk, sr=sr)
             result = self.asr_pipeline(reduced)
 
+            tokens = []
             if "chunks" in result:
                 for word_info in result["chunks"]:
-                    # Each word or token gets its own mini segment
-                    start_ts, end_ts = None, None
-
-                    if isinstance(word_info.get("timestamp"), (list, tuple)):
-                        start_ts, end_ts = word_info["timestamp"]
-                    elif isinstance(word_info.get("timestamp"), (float, int)):
-                        start_ts = word_info["timestamp"]
-                        end_ts = start_ts
-
-                    if start_ts is None:
-                        continue
-
-                    # Shift timestamps to align with full audio
-                    abs_start = segment_start + start_ts
-                    abs_end = segment_start + end_ts
-
-                    word_segment = {
-                        "speaker": spk,
-                        "segment_start": abs_start,
-                        "segment_end": abs_end,
-                        "tokens": [
-                            {
-                                "tag": "w",
-                                "start": abs_start,
-                                "end": abs_end,
-                                "text": word_info["text"].strip()
-                            }
-                        ]
-                    }
-
-                    all_word_segments.append(word_segment)
-
-                    if spk not in speaker_segments:
-                        speaker_segments[spk] = []
-                    speaker_segments[spk].append(word_segment)
-
-        return all_word_segments, list(speaker_segments.keys())
+                    start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
+                    tokens.append({
+                        "tag": "w",
+                        "start": start_ts,
+                        "end": end_ts,
+                        "text": word_info["text"]
+                    })
+
+            seg_dict = {
+                "speaker": spk,
+                "segment_start": segment_start,
+                "segment_end": segment_end,
+                "tokens": tokens
+            }
+            merged_segments.append(seg_dict)
+            print("Sample merged segment:", merged_segments[0])
+
+
+            if spk not in speaker_segments:
+                speaker_segments[spk] = []
+            speaker_segments[spk].append(seg_dict)
+
+        return merged_segments, list(speaker_segments.keys())
 
     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                      ref_rttm=None, ref_json=None):
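
For orientation, here is a minimal sketch of how the reworked run_transcription could be driven. Only this method appears in the commit, so the constructor call and the upstream diarization step are assumptions; the diar_json keys (segment_start, segment_end, speaker) and the shape of the returned segments are taken from the diff itself.

# Hypothetical driver; ASR_Diarization's constructor and diarization entry
# point are not shown in this commit, so their use here is an assumption.
from asr_diarization.pipeline import ASR_Diarization

pipeline = ASR_Diarization()

# Diarization output in the shape run_transcription reads:
diar_json = [
    {"segment_start": 0.00, "segment_end": 2.50, "speaker": "SPEAKER_00"},
    {"segment_start": 2.50, "segment_end": 5.00, "speaker": "SPEAKER_01"},
]

merged_segments, speakers = pipeline.run_transcription("audio.wav", diar_json)

# Each element of merged_segments is one diarized turn:
# {"speaker": "SPEAKER_00", "segment_start": 0.0, "segment_end": 2.5,
#  "tokens": [{"tag": "w", "start": ..., "end": ..., "text": ...}, ...]}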
 
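One behavioral change worth noting: the deleted word-level variant shifted every timestamp by segment_start (abs_start = segment_start + start_ts) so tokens lined up with the full recording, whereas the restored segment-level code stores the ASR pipeline's chunk-relative timestamps unchanged. If downstream consumers expect absolute times, a small post-processing pass can recover the old alignment; the helper below is a sketch under that assumption, not part of the commit.

def align_tokens_to_full_audio(merged_segments):
    """Shift chunk-relative token timestamps to absolute positions.

    Sketch only: assumes the segment/token shape produced by
    run_transcription; None timestamps are left as-is.
    """
    for seg in merged_segments:
        offset = seg["segment_start"]
        for tok in seg["tokens"]:
            if tok["start"] is not None:
                tok["start"] += offset
            if tok["end"] is not None:
                tok["end"] += offset
    return merged_segments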