quynhthames committed on
Commit eeeeb9c · 1 Parent(s): 81607f6
Files changed (2)
  1. app.py +334 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,334 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import json
+ import html
+ from itertools import groupby
+ from sentence_transformers import SentenceTransformer, util
+ from underthesea import sent_tokenize
+ from transformers import pipeline
+ import tempfile
+ import os
+ import gc
+
+ # === Setup Models & Tokens ===
+
+ HF_TOKEN = "REMOVED_SECRET"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load Whisper lazily inside the transcription function to save startup time
+ whisper_model = None
+
+ # Speaker diarization pipeline
+ from pyannote.audio import Pipeline
+ diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HF_TOKEN)
+ diarization_pipeline.to(device)
+
+ # Vietnamese punctuation corrector
+ corrector = pipeline("text2text-generation", model="bmd1905/vietnamese-correction-v2", device=0 if torch.cuda.is_available() else -1)
+
+ # SentenceTransformer for embeddings
+ embedding_model = SentenceTransformer("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base", device=str(device))
+
+ # Cache for embeddings and transcript
+ cached_transcript_segments = None
+ cached_embeddings = None
+
+ # Dynamic color generator
+ def generate_color_palette(n):
+     import colorsys
+     hues = np.linspace(0, 1, n, endpoint=False)
+     colors = []
+     for h in hues:
+         r, g, b = colorsys.hsv_to_rgb(h, 0.6, 0.9)
+         colors.append(f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 0.5)")
+     return colors
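+ # For example, generate_color_palette(3) yields three translucent rgba strings,
+ # the first being "rgba(229, 91, 91, 0.5)" (hue 0 at s=0.6, v=0.9).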
+
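+ # Note: the conversion step below shells out to the ffmpeg binary, which must
+ # be installed on the host system (it is not a pip dependency).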
+ # Step 1: Audio conversion
+ def convert_to_wav(audio_file):
+     import subprocess
+     if not audio_file:
+         return None, "No audio provided."
+     # gr.Audio(type="filepath") passes the upload as a path string
+     input_path = audio_file
+     # mkstemp instead of the deprecated tempfile.mktemp
+     fd, output_path = tempfile.mkstemp(suffix=".wav")
+     os.close(fd)
+     # Always re-encode so downstream models get mono 16 kHz WAV
+     try:
+         # ffmpeg command: 1 channel, 16000 Hz sample rate, wav format
+         subprocess.run(
+             ["ffmpeg", "-y", "-i", input_path, "-ac", "1", "-ar", "16000", output_path],
+             stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
+     except Exception as e:
+         return None, f"Error converting audio: {e}"
+     return output_path, "Audio converted to WAV."
+
+ # Step 2: Transcription
+ def transcribe_audio(wav_path, progress=gr.Progress()):
+     global whisper_model
+     if whisper_model is None:
+         import whisper
+         whisper_model = whisper.load_model("large", device=str(device))
+     progress(0.1, desc="Transcribing audio...")
+     result = whisper_model.transcribe(wav_path, language="vi")
+     progress(1.0, desc="Transcription complete.")
+     return result
+
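+ # result["segments"] is a list of dicts whose "start"/"end" hold timestamps in
+ # seconds and whose "text" holds the recognized text; the merge step below
+ # relies on exactly those keys.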
+ # Step 3: Diarization
+ def diarize_audio(wav_path, progress=gr.Progress()):
+     progress(0.1, desc="Running diarization...")
+     diarization = diarization_pipeline(wav_path)
+     progress(1.0, desc="Diarization complete.")
+     return diarization
+
+ # Assign each Whisper segment to the diarization turn with the greatest
+ # temporal overlap, then merge consecutive segments of the same speaker
+ def merge_transcript_with_speakers(transcript_segments, diarization):
+     merged = []
+     for seg in transcript_segments:
+         start = seg["start"]
+         end = seg["end"]
+         text = seg["text"].strip()
+         speaker = "Unknown"
+         max_overlap = 0
+         for turn, _, label in diarization.itertracks(yield_label=True):
+             overlap = max(0, min(end, turn.end) - max(start, turn.start))
+             if overlap > max_overlap:
+                 speaker = label
+                 max_overlap = overlap
+         merged.append((speaker, text))
+     grouped = [
+         {"speaker": speaker, "text": ' '.join(text for _, text in group)}
+         for speaker, group in groupby(merged, key=lambda x: x[0])
+     ]
+     return grouped
+
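+ # Illustrative shape of the merged result (labels and text are hypothetical):
+ # [{"speaker": "SPEAKER_00", "text": "Xin chào quý vị..."},
+ #  {"speaker": "SPEAKER_01", "text": "Vâng, cảm ơn anh..."}]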
+ # Step 4: Punctuation correction
+ def correct_punctuation(transcript, progress=gr.Progress()):
+     MAX_LENGTH = 4096
+     BATCH_SIZE = 8
+     texts = [turn['text'] for turn in transcript]
+
+     def batch(lst, batch_size):
+         for i in range(0, len(lst), batch_size):
+             yield lst[i:i + batch_size]
+
+     corrected_texts = []
+     total_batches = (len(texts) + BATCH_SIZE - 1) // BATCH_SIZE
+     for i, text_batch in enumerate(batch(texts, BATCH_SIZE)):
+         progress(i / total_batches, desc="Correcting punctuation...")
+         predictions = corrector(text_batch, max_length=MAX_LENGTH)
+         corrected_texts.extend([p['generated_text'] for p in predictions])
+     progress(1.0, desc="Punctuation correction complete.")
+
+     for turn, corrected_text in zip(transcript, corrected_texts):
+         turn['text'] = corrected_text
+     return transcript
+
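+ # total_batches is a ceiling division: 17 turns with BATCH_SIZE=8 give
+ # (17 + 8 - 1) // 8 = 3 batches.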
+ # Step 5: Content analysis (keyword highlighting)
+ def highlight_transcript(transcript, keywords, percentile):
+     global cached_transcript_segments, cached_embeddings
+
+     # Only keep sentences with enough alphabetic words
+     def is_relevant_sentence(text, min_word_count=6):
+         words = [w for w in text.split() if w.isalpha()]
+         return len(words) >= min_word_count
+
+     # Flatten turns into sentences, remembering which turn each came from
+     flattened = []
+     for idx, turn in enumerate(transcript):
+         if is_relevant_sentence(turn["text"]):
+             for sent in sent_tokenize(turn["text"]):
+                 sent = sent.strip()
+                 if is_relevant_sentence(sent):
+                     flattened.append({"speaker": turn["speaker"], "text": sent, "turn_idx": idx})
+
+     # Sliding windows of consecutive sentences
+     def sliding_windows(sentences, window_size=2, step=1):
+         windows = []
+         for i in range(0, len(sentences) - window_size + 1, step):
+             chunk = sentences[i:i + window_size]
+             windows.append({
+                 "start_idx": i,
+                 "end_idx": i + window_size,
+                 "speakers": [s["speaker"] for s in chunk],
+                 "text": " ".join(s["text"] for s in chunk)
+             })
+         return windows
+
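+     # With window_size=2 and step=1, four sentences yield windows over
+     # sentence indices [0:2], [1:3] and [2:4].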
+     windows = sliding_windows(flattened)
+
+     # Re-embed only when the transcript changed since the last call;
+     # otherwise reuse the cached window embeddings
+     if cached_transcript_segments is None or cached_transcript_segments != transcript:
+         window_texts = [w["text"] for w in windows]
+         cached_embeddings = embedding_model.encode(window_texts, convert_to_tensor=True)
+         cached_transcript_segments = transcript
+
+     # Generate colors dynamically for keywords
+     unique_keywords = list(set(k.strip().lower() for k in keywords if k.strip() != ""))
+     colors = generate_color_palette(len(unique_keywords))
+     keyword_color_map = dict(zip(unique_keywords, colors))
+
+     matched_windows = []
+     for keyword in unique_keywords:
+         keyword_embedding = embedding_model.encode([keyword], convert_to_tensor=True)
+         sims = util.cos_sim(cached_embeddings, keyword_embedding).squeeze()
+         top_indices, threshold = auto_top_k(sims.cpu().numpy(), percentile=percentile)
+         for i in top_indices:
+             matched_windows.append({
+                 "start": windows[i]["start_idx"],
+                 "end": windows[i]["end_idx"],
+                 "keywords": [{
+                     "keyword": keyword,
+                     "color": keyword_color_map[keyword],
+                     "score": sims[i].item()
+                 }]
+             })
+
+     # Merge overlapping windows
+     matched_windows.sort(key=lambda x: x["start"])
+     merged = []
+     for w in matched_windows:
+         if not merged or w["start"] > merged[-1]["end"]:
+             merged.append(w.copy())
+         else:
+             merged[-1]["end"] = max(merged[-1]["end"], w["end"])
+             merged[-1]["keywords"].extend(w["keywords"])
+
+     # Build highlight map
+     highlight_map = {}
+     for mw in merged:
+         for idx in range(mw["start"], mw["end"]):
+             sent_info = flattened[idx]
+             turn_idx = sent_info["turn_idx"]
+             if turn_idx not in highlight_map:
+                 highlight_map[turn_idx] = []
+             highlight_map[turn_idx].extend(mw["keywords"])
+
+     # Compose HTML transcript with highlights, speaker colors & tooltip similarity scores
+     # Assign a color per speaker
+     speakers = list(set(turn["speaker"] for turn in transcript))
+     speaker_colors = generate_color_palette(len(speakers))
+     speaker_color_map = dict(zip(speakers, speaker_colors))
+
+     html_lines = []
+     for i, turn in enumerate(transcript):
+         sp = turn["speaker"]
+         sp_color = speaker_color_map.get(sp, "black")
+         text = html.escape(turn["text"])
+         # Apply highlights for keywords in this turn
+         if i in highlight_map:
+             keywords_info = highlight_map[i]
+             # Combine same keywords (by name)
+             combined = {}
+             for k in keywords_info:
+                 combined[k["keyword"]] = k
+             # Sort keywords by score desc
+             sorted_kw = sorted(combined.values(), key=lambda x: x["score"], reverse=True)
+             tooltip_text = ", ".join(f'{kw["keyword"]} ({kw["score"]:.3f})' for kw in sorted_kw)
+             # Wrap keywords with colored background spans
+             for kw in sorted_kw:
+                 # Replace all keyword occurrences (case insensitive)
+                 text = replace_case_insensitive(text, kw["keyword"], f'<span style="background-color:{kw["color"]};" title="{tooltip_text}">{kw["keyword"]}</span>')
+         # Speaker label with color
+         html_lines.append(f'<p><b><span style="color:{sp_color};">Speaker: {sp}</span></b><br>{text}</p>')
+     final_html = "<br>".join(html_lines)
+     return final_html
+
+ def auto_top_k(similarities, percentile=90):
+     threshold = np.percentile(similarities, percentile)
+     top_indices = np.where(similarities >= threshold)[0]
+     return top_indices, threshold
+
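+ # With the default percentile=90, auto_top_k keeps roughly the top 10% most
+ # similar windows; the threshold adapts to each keyword's score distribution.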
+ def replace_case_insensitive(text, keyword, replacement):
+     import re
+     pattern = re.compile(re.escape(keyword), re.IGNORECASE)
+     return pattern.sub(replacement, text)
+
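+ # Example of the pasted-transcript JSON accepted below (illustrative values):
+ # [{"speaker": "SPEAKER_00", "text": "Xin chào quý vị."},
+ #  {"speaker": "SPEAKER_01", "text": "Cảm ơn anh."}]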
+ # Write the transcript JSON to a temp file so Gradio can serve it for download
+ def save_transcript_json(transcript):
+     fd, path = tempfile.mkstemp(suffix=".json")
+     with os.fdopen(fd, "w", encoding="utf-8") as f:
+         json.dump(transcript, f, ensure_ascii=False, indent=2)
+     return path
+
+ # Main app function
+ def run_pipeline(audio_file, keywords_raw, percentile, transcript_input):
+     keywords = [k.strip().lower() for k in keywords_raw.split(",") if k.strip() != ""]
+     if transcript_input.strip():
+         # Use the pasted transcript; expected format is a JSON list like
+         # [{"speaker": "spk1", "text": "..."}]
+         try:
+             transcript = json.loads(transcript_input)
+         except json.JSONDecodeError:
+             return "", "Invalid transcript JSON format.", None
+         transcript_html = highlight_transcript(transcript, keywords, percentile)
+         # Prepare JSON for download
+         json_path = save_transcript_json(transcript)
+         return transcript_html, "Loaded transcript and analyzed.", json_path
+     if not audio_file:
+         return "", "Please upload an audio file or paste a transcript.", None
+     # Convert audio
+     wav_path, msg = convert_to_wav(audio_file)
+     if not wav_path:
+         return "", msg, None
+     # Transcribe
+     result = transcribe_audio(wav_path)
+     segments = result["segments"]
+     # Diarize
+     diarization = diarize_audio(wav_path)
+     # Merge transcript with speakers
+     merged = merge_transcript_with_speakers(segments, diarization)
+     # Punctuation correction
+     merged = correct_punctuation(merged)
+     # Content analysis + highlighting
+     transcript_html = highlight_transcript(merged, keywords, percentile)
+     # Save JSON for download
+     json_path = save_transcript_json(merged)
+
+     # Cleanup temp files
+     try:
+         os.remove(wav_path)
+     except OSError:
+         pass
+     gc.collect()
+     return transcript_html, "Processing complete.", json_path
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Vietnamese Audio Transcript & Keyword Analysis")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             audio_input = gr.Audio(label="Upload or record audio (16kHz mono WAV recommended)", source="upload", type="filepath")
+             transcript_input = gr.Textbox(label="Or paste final transcript JSON (skip upload & transcription)", lines=6, placeholder='Paste JSON here')
+             keywords_input = gr.Textbox(label="Enter keywords separated by commas", value="hoa hồng, chiến lược giá")
+             percentile_slider = gr.Slider(50, 100, value=90, step=1, label="Similarity percentile threshold for keyword matching")
+             proceed_btn = gr.Button("Proceed")
+
+         with gr.Column(scale=3):
+             output_html = gr.HTML()
+             status_text = gr.Textbox(label="Status", interactive=False)
+             transcript_file = gr.File(label="Download transcript JSON")
+
+     # A Button is not a valid input component, so it is not passed to run_pipeline
+     proceed_btn.click(
+         run_pipeline,
+         inputs=[audio_input, keywords_input, percentile_slider, transcript_input],
+         outputs=[output_html, status_text, transcript_file],
+     )
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ git+https://github.com/openai/whisper.git
+ pyannote.audio
+ sentence_transformers
+ underthesea
+ pyvi
+ torch
+ numpy
+ transformers
+ gradio