Spaces:
Sleeping
Sleeping
| import logging | |
| import subprocess | |
| import time | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Tuple | |
| from dataclasses import dataclass | |
| from collections import defaultdict, Counter | |
| import numpy as np | |
| import librosa | |
| from app.core.config import get_settings | |
| from app.services.transcription import TranscriptionService | |
| from app.services.alignment import AlignmentService | |
| from app.services.transcription import WordTimestamp | |
| from app.services.emo import EmotionService | |
| from app.services.diarization import ( | |
| DiarizationService, | |
| SpeakerSegment, | |
| DiarizationResult, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| settings = get_settings() | |
@dataclass
class TranscriptSegment:
    """A transcribed segment with speaker info.

    Declared as a dataclass because the pipeline builds instances with both
    positional and keyword arguments and later assigns ``emotion``/``icon``
    in place, so a generated ``__init__`` and mutable fields are required.
    """

    # Segment boundaries, in seconds from the start of the audio.
    start: float
    end: float
    # Display label of the speaker (e.g. "Speaker 1").
    speaker: str
    # Role attached to the speaker, or None when unknown.
    role: Optional[str]
    # Transcribed text of the segment.
    text: str
    # Predicted emotion label and its emoji; both stay None until predicted.
    emotion: Optional[str] = None
    icon: Optional[str] = None
@dataclass
class EmotionPoint:
    """One entry of the emotion timeline.

    Declared as a dataclass so it can be built with keyword arguments
    (``EmotionPoint(time=..., emotion=..., icon=...)``).
    """

    # Time of the point, in seconds (taken from the segment start).
    time: float
    # Emotion label at that time.
    emotion: str
    # Emoji associated with the emotion, or None.
    icon: Optional[str]
@dataclass
class EmotionChange:
    """A transition between two different emotions on the timeline.

    Declared as a dataclass so it can be built with keyword arguments.
    """

    # Time at which the new emotion starts, in seconds.
    time: float
    # Emotion label before and after the change.
    emotion_from: str
    emotion_to: str
    # Emojis matching ``emotion_from`` / ``emotion_to``; optional.
    icon_from: Optional[str] = None
    icon_to: Optional[str] = None
@dataclass
class ProcessingResult:
    """Result of audio processing.

    Declared as a dataclass so it can be built with keyword arguments.
    Element types are written as forward references so the class does not
    depend on the definition order of the other result types.
    """

    # Transcript segments in chronological order.
    segments: List["TranscriptSegment"]
    # Number of distinct speakers.
    speaker_count: int
    # Audio duration and wall-clock processing time, both in seconds.
    duration: float
    processing_time: float
    # Speaker display labels and the role attached to each label.
    speakers: List[str]
    roles: Dict[str, str]
    # Rendered exports of the transcript.
    txt_content: str = ""
    csv_content: str = ""
    # Emotion data; None (not an empty list) when it was never computed.
    emotion_timeline: Optional[List["EmotionPoint"]] = None
    emotion_changes: Optional[List["EmotionChange"]] = None
def normalize_asr_result(result: dict) -> Tuple[str, List[dict]]:
    """Extract the transcript text and a cleaned word list from an ASR result.

    Args:
        result: Mapping that may contain ``"text"`` and ``"words"``. Each
            word entry must provide ``"start"`` and ``"end"``; ``"word"``,
            ``"speaker"`` and ``"confidence"`` are optional.

    Returns:
        ``(text, words)`` where ``text`` is the stripped transcript and
        ``words`` is a list of dicts with the keys ``word``, ``start``,
        ``end``, ``speaker`` and ``confidence``. Entries whose word is
        empty after stripping are dropped.

    Raises:
        KeyError: if a kept word entry lacks ``"start"`` or ``"end"``.
    """
    words = []
    # `or` rather than a .get() default: a key that is present but None
    # must be treated like a missing key instead of crashing below.
    for entry in result.get("words") or []:
        word = (entry.get("word") or "").strip()
        if not word:
            continue

        confidence = entry.get("confidence")
        words.append(
            {
                "word": word,
                "start": float(entry["start"]),
                "end": float(entry["end"]),
                "speaker": entry.get("speaker"),
                # A missing or None confidence counts as fully confident.
                "confidence": 1.0 if confidence is None else float(confidence),
            }
        )

    text = (result.get("text") or "").strip()
    return text, words
def guess_speaker_by_overlap(start, end, diar_segments):
    """Pick the speaker whose diarization segment overlaps ``[start, end]`` most.

    Args:
        start: Start of the interval to attribute.
        end: End of the interval to attribute.
        diar_segments: Sequence of objects exposing ``start``, ``end`` and
            ``speaker``.

    Returns:
        The speaker of the segment with the largest positive overlap (the
        earliest such segment wins ties). When nothing overlaps, the speaker
        of the first segment. ``None`` when ``diar_segments`` is empty.
    """
    if not diar_segments:
        # Nothing to attribute to; previously this raised IndexError.
        return None

    best_spk = None
    best_overlap = 0.0
    for seg in diar_segments:
        overlap = max(0.0, min(end, seg.end) - max(start, seg.start))
        # Strict comparison keeps the first segment on equal overlap.
        if overlap > best_overlap:
            best_overlap = overlap
            best_spk = seg.speaker

    # Compare against None so a falsy-but-valid speaker id is not discarded.
    return best_spk if best_spk is not None else diar_segments[0].speaker
def convert_audio_to_wav(audio_path: Path) -> Path:
    """Convert any audio to WAV 16kHz Mono using ffmpeg.

    The converted file is written next to the input as
    ``<stem>_processed.wav``; an existing file with that name is replaced.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        The path of the converted WAV on success. If ffmpeg fails or cannot
        be started, the error is logged and ``audio_path`` is returned
        unchanged (best-effort fallback).
    """
    output_path = audio_path.parent / f"{audio_path.stem}_processed.wav"
    # Remove any stale output first so a failed run cannot leave an old
    # conversion behind.
    if output_path.exists():
        output_path.unlink()

    command = [
        "ffmpeg",
        "-i",
        str(audio_path),
        "-ar",
        "16000",
        "-ac",
        "1",
        "-y",
        str(output_path),
    ]

    try:
        subprocess.run(
            command,
            check=True,
            # ffmpeg reads interactive commands from stdin; detach it so a
            # background worker does not have its stdin consumed.
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            # Keep stderr: it is the only place ffmpeg explains a failure.
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        details = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        # Only the tail is logged; the head is ffmpeg's version banner.
        logger.error(f"FFmpeg conversion failed: {e}; stderr: {details[-2000:]}")
        return audio_path
    except OSError as e:
        # ffmpeg is missing or not executable: take the same fallback.
        logger.error(f"FFmpeg conversion failed: {e}")
        return audio_path

    logger.info(f"Converted audio to WAV: {output_path}")
    return output_path
def format_timestamp(seconds: float) -> str:
    """Format a duration in seconds as ``MM:SS.mmm``.

    The value is rounded to whole milliseconds *before* it is split into
    minutes and seconds, so a value such as 59.9996 becomes ``01:00.000``
    rather than the invalid ``00:60.000``. Minutes are not wrapped into
    hours (3600 seconds renders as ``60:00.000``).

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted timestamp string.
    """
    total_ms = int(round(seconds * 1000))
    minutes, remainder_ms = divmod(total_ms, 60_000)
    return f"{minutes:02d}:{remainder_ms / 1000:06.3f}"
def merge_consecutive_segments(
    segments: List["SpeakerSegment"],
    max_gap: float = 0.80,
    max_overlap: float = 0.15,
) -> List["SpeakerSegment"]:
    """Merge neighbouring segments that belong to the same speaker.

    Segments are processed in order of ``start``. A segment is folded into
    the previous merged one when both have the same speaker and the gap
    between them (``seg.start - prev.end``) lies within
    ``[-max_overlap, max_gap]``; the merged segment then ends at the later
    of the two ends.

    The input list and its segment objects are left untouched: merging is
    performed on shallow copies.

    Args:
        segments: Objects exposing ``start``, ``end`` and ``speaker``.
        max_gap: Largest silence between two segments that still merges.
        max_overlap: Largest overlap between two segments that still merges.

    Returns:
        A new list of merged segments sorted by start; empty for empty input.
    """
    import copy

    if not segments:
        return []

    ordered = sorted(segments, key=lambda seg: seg.start)

    # Copy the seed so extending its end never alters the caller's object.
    merged = [copy.copy(ordered[0])]
    for seg in ordered[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and -max_overlap <= gap <= max_gap:
            prev.end = max(prev.end, seg.end)
        else:
            # This copy may itself be extended by a later iteration.
            merged.append(copy.copy(seg))

    return merged
def overlap_prefix(a: str, b: str, n: int = 12) -> bool:
    """Tell whether the first ``n`` characters of one string occur in the other.

    Both strings are stripped and lower-cased before comparison.

    Args:
        a: First string.
        b: Second string.
        n: Length of the prefix to look for.

    Returns:
        True if the ``n``-character prefix of ``a`` occurs anywhere in ``b``
        or vice versa. False if either string is empty or whitespace-only.
    """
    if not a or not b:
        return False

    a = a.strip().lower()
    b = b.strip().lower()
    # Re-check after stripping: an empty prefix is "contained" in every
    # string, so whitespace-only input would otherwise always match.
    if not a or not b:
        return False

    return a[:n] in b or b[:n] in a
class Processor:
    """Audio pipeline: convert, diarize, transcribe, align, tag emotions, export.

    The class holds no instance state; every method is a class method or a
    static method and is meant to be called on the class itself.
    """

    # Role used when diarization supplies none. Only segments carrying this
    # role receive emotion predictions.
    # NOTE(review): "KH" presumably denotes the customer side of the call — confirm.
    _DEFAULT_ROLE = "KH"

    # Emoji used when EmotionService.meta has no entry for an emotion.
    # NOTE(review): the previous literal was unreadable (encoding damage);
    # the neutral face is a best guess — confirm the intended glyph.
    _FALLBACK_EMOJI = "😐"

    # Per-speaker markers for the text export, cycled when there are more
    # speakers than icons, plus the marker for an unrecognised speaker.
    _SPEAKER_ICONS = ("🔵", "🟢", "🟡", "🟠", "🔴", "🟣")
    _UNKNOWN_SPEAKER_ICON = "⚪"

    # Separator between the start and end timestamps in the text export.
    # NOTE(review): the previous literal was unreadable (encoding damage);
    # an arrow is a best guess — confirm the intended glyph.
    _RANGE_SEPARATOR = "→"

    @classmethod
    async def process_audio(
        cls,
        audio_path: Path,
        model_name: str = "PhoWhisper Lora Finetuned",
        language="vi",
        merge_segments: bool = True,
    ) -> "ProcessingResult":
        """Run the full pipeline on one audio file.

        Args:
            audio_path: Path of the input audio file.
            model_name: ASR model name passed to the transcription service.
            language: Language code passed to the transcription service.
            merge_segments: Whether to merge neighbouring diarization
                segments of the same speaker before alignment.

        Returns:
            A ``ProcessingResult`` with the transcript segments, speaker
            labels and roles, text/CSV exports and emotion data.

        Raises:
            ValueError: if the loaded audio contains no samples.
        """
        import asyncio
        import functools

        t0 = time.time()
        EmotionService.preload_model()
        # Inside a coroutine the running loop is the one to schedule on.
        loop = asyncio.get_running_loop()

        # 1: Convert to WAV
        logger.info("Step 1: Converting audio to WAV 16kHz...")
        wav_path = await loop.run_in_executor(None, convert_audio_to_wav, audio_path)

        # 2: Load audio (decoding blocks, so keep it off the event loop)
        y, sr = await loop.run_in_executor(
            None, functools.partial(librosa.load, wav_path, sr=16000, mono=True)
        )
        if y.size == 0:
            raise ValueError("Empty audio")
        duration = len(y) / sr

        # 3: Diarization
        logger.info("Step 3: Running diarization...")
        diarization = await DiarizationService.diarize_async(wav_path)
        # sorted() copies the list, leaving diarization.segments in its own order.
        diarization_segments = sorted(
            diarization.segments or [], key=lambda seg: seg.start
        )
        if not diarization_segments:
            # No speaker turns found: treat the whole file as one speaker.
            diarization_segments = [SpeakerSegment(0.0, duration, "SPEAKER_0")]

        if merge_segments:
            logger.info("Step 4: Merging consecutive segments...")
            diarization_segments = merge_consecutive_segments(diarization_segments)

        # 4: Normalize speakers to "Speaker 1", "Speaker 2", ... and carry
        # the diarization roles over to the new labels.
        raw_speakers = sorted({seg.speaker for seg in diarization_segments})
        speaker_map = {
            raw: f"Speaker {i}" for i, raw in enumerate(raw_speakers, start=1)
        }
        speakers = list(speaker_map.values())
        raw_roles = diarization.roles or {}
        roles = {
            label: raw_roles.get(raw, cls._DEFAULT_ROLE)
            for raw, label in speaker_map.items()
        }
        logger.info(f"roles(mapped) = {roles}")

        # 7: Transcribe the whole recording with word timestamps
        logger.info("Step 7: Running ASR with external VAD batch...")
        asr_result = await TranscriptionService.transcribe_with_words_async(
            audio_array=y,
            model_name=model_name,
            language=language,
            vad_options=False,
        )
        text, raw_words = normalize_asr_result(asr_result)
        processed_segments = cls._build_transcript_segments(
            text,
            raw_words,
            diarization_segments,
            speaker_map,
            roles,
            speakers[0],
            duration,
        )

        # 8: Predict emotion segments
        logger.info("Step 8: Predicting emo per segment ")
        processed_segments = cls._predict_emotion_segments(processed_segments, y, sr)
        emotion_timeline = cls.build_emotion_timeline(processed_segments)
        emotion_changes = cls.detect_emotion_changes(emotion_timeline)

        processing_time = time.time() - t0
        txt_content = cls._generate_txt(
            processed_segments, len(speakers), processing_time, duration, roles
        )
        csv_content = cls._generate_csv(processed_segments)

        return ProcessingResult(
            segments=processed_segments,
            speaker_count=len(speakers),
            duration=duration,
            processing_time=processing_time,
            speakers=speakers,
            roles=roles,
            txt_content=txt_content,
            csv_content=csv_content,
            emotion_timeline=emotion_timeline,
            emotion_changes=emotion_changes,
        )

    @classmethod
    def _build_transcript_segments(
        cls,
        text,
        raw_words,
        diarization_segments,
        speaker_map,
        roles,
        fallback_label,
        duration,
    ):
        """Turn normalized ASR words into speaker-labelled transcript segments.

        Args:
            text: Full transcript text.
            raw_words: Word dicts as returned by ``normalize_asr_result``.
            diarization_segments: Non-empty list of diarization segments.
            speaker_map: Raw speaker id -> display label.
            roles: Display label -> role.
            fallback_label: Label used when a raw speaker id is unknown.
            duration: Audio duration in seconds.

        Returns:
            Transcript segments sorted by start time; a single placeholder
            segment when no words were recognised.
        """
        if not raw_words:
            return [
                TranscriptSegment(
                    start=0.0,
                    end=duration,
                    speaker=fallback_label,
                    role=roles[fallback_label],
                    text="(No speech detected)",
                )
            ]

        # Words without a speaker get the one whose turn overlaps them most.
        word_objs = []
        for w in raw_words:
            spk = w.get("speaker")
            if spk is None:
                spk = guess_speaker_by_overlap(
                    w["start"], w["end"], diarization_segments
                )
            word_objs.append(
                WordTimestamp(
                    word=w["word"],
                    start=w["start"],
                    end=w["end"],
                    speaker=spk,
                    confidence=w.get("confidence", 1.0),
                )
            )
        word_objs.sort(key=lambda word: word.start)

        aligned_segments = AlignmentService.align_precision(
            word_objs, diarization_segments
        )

        if not aligned_segments:
            # Alignment produced nothing: emit the whole transcript as one
            # segment attributed to the most frequent word-level speaker.
            votes = [w.speaker for w in word_objs if w.speaker]
            if votes:
                raw_spk = Counter(votes).most_common(1)[0][0]
            else:
                raw_spk = diarization_segments[0].speaker
            label = speaker_map.get(raw_spk, fallback_label)
            return [
                TranscriptSegment(
                    start=0.0,
                    end=duration,
                    speaker=label,
                    role=roles.get(label, cls._DEFAULT_ROLE),
                    text=text,
                )
            ]

        segments = []
        for seg in aligned_segments:
            label = speaker_map.get(seg.speaker, fallback_label)
            segments.append(
                TranscriptSegment(
                    start=seg.start,
                    end=seg.end,
                    speaker=label,
                    role=roles.get(label, cls._DEFAULT_ROLE),
                    text=seg.text,
                )
            )
        segments.sort(key=lambda seg: seg.start)
        return segments

    @staticmethod
    def _emoji_for(emotion, default):
        """Return the emoji EmotionService.meta lists for ``emotion``, else ``default``."""
        return EmotionService.meta.get(emotion, {}).get("emoji", default)

    @staticmethod
    def _predict_emotion_segments(
        segments: List["TranscriptSegment"], audio: np.ndarray, sr: int
    ):
        """Fill ``emotion`` and ``icon`` on every segment, in place.

        Only segments whose role is the default role are sent to the emotion
        model; all other segments have both fields reset to None.

        Returns:
            The same ``segments`` list.
        """
        for seg in segments:
            # Only predict emotion for the default ("KH") role.
            if seg.role != Processor._DEFAULT_ROLE:
                seg.emotion = None
                seg.icon = None
                continue

            emotion = EmotionService.predict_segment(audio, sr, seg.start, seg.end)
            seg.emotion = emotion
            seg.icon = Processor._emoji_for(emotion, Processor._FALLBACK_EMOJI)
        return segments

    @staticmethod
    def build_emotion_timeline(segments):
        """Build the emotion timeline from already-tagged segments.

        One ``EmotionPoint`` (placed at the segment start) is produced per
        segment that has the default role and a non-empty emotion and icon.
        """
        timeline = []
        for seg in segments:
            if seg.role != Processor._DEFAULT_ROLE:
                continue
            if not seg.emotion or not seg.icon:
                continue

            icon = Processor._emoji_for(seg.emotion, Processor._FALLBACK_EMOJI)
            timeline.append(
                EmotionPoint(time=seg.start, emotion=seg.emotion, icon=icon)
            )
        return timeline

    @staticmethod
    def detect_emotion_changes(timeline):
        """List every point of the timeline where the emotion differs from the previous one.

        Returns:
            One ``EmotionChange`` per transition, timed at the later point.
        """
        changes = []
        prev = None
        for point in timeline:
            if prev is not None and prev.emotion != point.emotion:
                changes.append(
                    EmotionChange(
                        time=point.time,
                        emotion_from=prev.emotion,
                        emotion_to=point.emotion,
                        icon_from=Processor._emoji_for(
                            prev.emotion, Processor._FALLBACK_EMOJI
                        ),
                        icon_to=Processor._emoji_for(
                            point.emotion, Processor._FALLBACK_EMOJI
                        ),
                    )
                )
            prev = point
        return changes

    @classmethod
    def _generate_txt(
        cls,
        segments: List["TranscriptSegment"],
        speaker_count: int,
        processing_time: float,
        duration: float,
        roles: Dict[str, str],
    ) -> str:
        """Render the transcript as plain text.

        The output starts with a ``#``-prefixed header (duration, speaker
        count, roles, processing time) followed by one line per segment in
        chronological order.
        """
        segments = sorted(segments, key=lambda s: s.start)

        # Distinct speakers in order of first appearance; dict keys keep
        # insertion order and make the membership test O(1).
        speakers = list(dict.fromkeys(s.speaker for s in segments if s.speaker))
        speaker_icons = {
            spk: cls._SPEAKER_ICONS[i % len(cls._SPEAKER_ICONS)]
            for i, spk in enumerate(speakers)
        }

        lines = [
            "# Transcription Result",
            f"# Duration: {format_timestamp(duration)}",
            f"# Speakers: {speaker_count}",
            f"# Roles: {roles}",
            f"# Processing time: {processing_time:.1f}s",
            "",
        ]

        for seg in segments:
            ts = (
                f"[{format_timestamp(seg.start)} {cls._RANGE_SEPARATOR} "
                f"{format_timestamp(seg.end)}]"
            )
            role = seg.role or "UNKNOWN"
            speaker_icon = speaker_icons.get(seg.speaker, cls._UNKNOWN_SPEAKER_ICON)
            emotion = seg.emotion or ""
            emotion_icon = cls._emoji_for(emotion, "") if emotion else ""
            lines.append(
                f"{ts} {speaker_icon} [{seg.speaker}|{role}] {seg.text} {emotion_icon} {emotion}"
            )

        return "\n".join(lines)

    @classmethod
    def _generate_csv(cls, segments: List["TranscriptSegment"]) -> str:
        """Render the segments as CSV text.

        Columns: start, end, speaker, text, emotion, icon. Times are rounded
        to three decimals; segments without an emotion get empty emotion and
        icon cells.
        """
        import csv
        from io import StringIO

        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(["start", "end", "speaker", "text", "emotion", "icon"])

        for seg in segments:
            emotion = seg.emotion or ""
            icon = cls._emoji_for(emotion, "") if emotion else ""
            writer.writerow(
                [
                    round(seg.start, 3),
                    round(seg.end, 3),
                    seg.speaker,
                    seg.text,
                    emotion,
                    icon,
                ]
            )

        return output.getvalue()