# Source: PoC_ASR_v6_dev/app/services/processor.py
# Revision e58b51a (verified) by vyluong: "Update app/services/processor.py"
import logging
import subprocess
import time
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict, Counter
import numpy as np
import librosa
from app.core.config import get_settings
from app.services.transcription import TranscriptionService
from app.services.alignment import AlignmentService
from app.services.transcription import WordTimestamp
from app.services.emo import EmotionService
from app.services.diarization import (
DiarizationService,
SpeakerSegment,
DiarizationResult,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Application settings, fetched once when this module is imported.
# NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
settings = get_settings()
@dataclass
class TranscriptSegment:
    """A transcribed segment with speaker info.

    Times are in seconds (they are rendered through ``format_timestamp``).
    """

    # Segment start time, in seconds from the beginning of the audio.
    start: float
    # Segment end time, in seconds from the beginning of the audio.
    end: float
    # Display label of the speaker (``Processor`` emits "Speaker N" labels).
    speaker: str
    # Role label of the speaker; ``Processor`` uses "KH" as the default role.
    role: Optional[str]
    # Transcribed text of the segment.
    text: str
    # Emotion label for the segment; None when no emotion was predicted.
    emotion: Optional[str] = None
    # Emoji associated with ``emotion``; None when no emotion was predicted.
    icon: Optional[str] = None
@dataclass
class EmotionPoint:
    """One sample of the emotion timeline."""

    # Time of the sample; ``Processor`` fills it with a segment's start time.
    time: float
    # Emotion label observed at ``time``.
    emotion: str
    # Emoji associated with ``emotion``.
    icon: Optional[str]
@dataclass
class EmotionChange:
    """A transition between two different emotions on the emotion timeline."""

    # Time of the timeline point at which the new emotion is first observed.
    time: float
    # Emotion label before the change.
    emotion_from: str
    # Emotion label after the change.
    emotion_to: str
    # Emoji for ``emotion_from``.
    icon_from: Optional[str] = None
    # Emoji for ``emotion_to``.
    icon_to: Optional[str] = None
@dataclass
class ProcessingResult:
    """Result of audio processing."""

    # Transcript segments produced for the audio.
    segments: List[TranscriptSegment]
    # Number of distinct speakers.
    speaker_count: int
    # Audio duration in seconds.
    duration: float
    # Wall-clock time spent processing, in seconds.
    processing_time: float
    # Display labels of the speakers.
    speakers: List[str]
    # Mapping of speaker display label -> role label.
    roles: Dict[str, str]
    # Plain-text rendering of the transcript.
    txt_content: str = ""
    # CSV rendering of the transcript.
    csv_content: str = ""
    # Emotion samples over time; defaults to None (not an empty list).
    emotion_timeline: Optional[List[EmotionPoint]] = None
    # Detected emotion transitions; defaults to None (not an empty list).
    emotion_changes: Optional[List[EmotionChange]] = None
def normalize_asr_result(result: dict) -> Tuple[str, List[dict]]:
    """Normalize a raw ASR result into ``(text, words)``.

    Args:
        result: ASR output mapping. The keys read are ``"text"`` (the full
            transcript) and ``"words"`` (an iterable of per-word mappings
            with ``"word"``, ``"start"``, ``"end"`` and, optionally,
            ``"speaker"`` and ``"confidence"``).

    Returns:
        A ``(text, words)`` tuple. ``text`` is the stripped transcript
        (``""`` when missing or ``None``). ``words`` is a list of dicts with
        the keys ``word``, ``start``, ``end``, ``speaker`` and
        ``confidence``; entries whose word is empty after stripping are
        dropped, and a missing/``None`` confidence becomes ``1.0``.

    Raises:
        KeyError: If a kept word entry has no ``"start"`` or ``"end"`` key.
    """
    words = []
    # ``or`` (rather than a .get default) also covers keys that are present
    # but explicitly set to None.
    for entry in result.get("words") or []:
        token = (entry.get("word") or "").strip()
        if not token:
            continue
        confidence = entry.get("confidence")
        words.append(
            {
                "word": token,
                "start": float(entry["start"]),
                "end": float(entry["end"]),
                "speaker": entry.get("speaker"),
                "confidence": 1.0 if confidence is None else float(confidence),
            }
        )
    text = (result.get("text") or "").strip()
    return text, words
def guess_speaker_by_overlap(start, end, diar_segments):
    """Pick the speaker whose diarization segment best matches ``[start, end]``.

    Args:
        start: Start time of the span to attribute.
        end: End time of the span to attribute.
        diar_segments: Sequence of objects exposing ``start``, ``end`` and
            ``speaker`` attributes.

    Returns:
        The ``speaker`` of the segment with the largest positive overlap.
        When nothing overlaps (the span falls in a gap, or has zero length)
        the speaker of the segment closest in time is returned. ``None`` is
        returned when ``diar_segments`` is empty.
    """
    if not diar_segments:
        return None

    best_seg = None
    best_overlap = 0.0
    for seg in diar_segments:
        overlap = min(end, seg.end) - max(start, seg.start)
        if overlap > best_overlap:
            best_overlap = overlap
            best_seg = seg
    # Track the segment, not the speaker value, so that a falsy speaker id
    # (e.g. 0) is not mistaken for "no match".
    if best_seg is not None:
        return best_seg.speaker

    def distance(seg):
        # Temporal gap between the span and the segment; 0 when they touch
        # or when a zero-length span lies inside the segment.
        return max(seg.start - end, start - seg.end, 0.0)

    # min() keeps the earliest candidate on ties.
    return min(diar_segments, key=distance).speaker
def convert_audio_to_wav(audio_path: Path) -> Path:
    """Convert any audio to WAV 16kHz Mono using ffmpeg.

    The result is written next to the input as ``<stem>_processed.wav``; an
    existing file with that name is replaced.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        The path of the converted WAV file, or ``audio_path`` unchanged when
        the conversion could not be performed (ffmpeg exited with an error,
        or the ffmpeg executable could not be started).
    """
    output_path = audio_path.parent / f"{audio_path.stem}_processed.wav"
    if output_path.exists():
        output_path.unlink()
    command = [
        "ffmpeg",
        # Never read from the parent's stdin (ffmpeg is interactive by default).
        "-nostdin",
        # Keep stderr limited to real errors so it can be captured and logged.
        "-loglevel",
        "error",
        "-i",
        str(audio_path),
        "-ar",
        "16000",
        "-ac",
        "1",
        "-y",
        str(output_path),
    ]
    try:
        subprocess.run(
            command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"FFmpeg conversion failed: {e}")
        detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        if detail:
            # Only the tail: the last lines carry the actual failure reason.
            logger.error("FFmpeg stderr: %s", detail[-2000:])
        return audio_path
    except OSError as e:
        # Raised when the ffmpeg executable is missing or cannot be executed;
        # fall back to the original file just like a failed conversion.
        logger.error(f"FFmpeg conversion failed: {e}")
        return audio_path
    logger.info(f"Converted audio to WAV: {output_path}")
    return output_path
def format_timestamp(seconds: float) -> str:
    """Format a duration in seconds as ``MM:SS.mmm``.

    Minutes are not wrapped into hours (3661.25 s -> ``"61:01.250"``).

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted timestamp, rounded to the nearest millisecond.
    """
    # Round to whole milliseconds *before* splitting into minutes/seconds;
    # otherwise 59.9996 s would be rendered as "00:60.000".
    total_ms = int(round(seconds * 1000))
    minutes, rem_ms = divmod(total_ms, 60_000)
    secs, millis = divmod(rem_ms, 1000)
    return f"{minutes:02d}:{secs:02d}.{millis:03d}"
def merge_consecutive_segments(
    segments: List[SpeakerSegment],
    max_gap: float = 0.80,
    max_overlap: float = 0.15,
) -> List[SpeakerSegment]:
    """Merge neighbouring segments that belong to the same speaker.

    Segments are processed in order of ``start``. A segment is folded into
    the previous output segment when both have the same ``speaker`` and the
    gap between them (``seg.start - prev.end``) lies within
    ``[-max_overlap, max_gap]``.

    Args:
        segments: Segments exposing ``start``, ``end`` and ``speaker``.
        max_gap: Largest silence between two segments that still allows
            merging.
        max_overlap: Largest overlap between two segments that still allows
            merging.

    Returns:
        A new list of segments sorted by ``start``. The returned objects are
        shallow copies, so the caller's segments are never modified.
    """
    import copy

    if not segments:
        return []

    ordered = sorted(segments, key=lambda seg: seg.start)
    # Copy before extending ``end`` so the input objects stay untouched.
    merged = [copy.copy(ordered[0])]
    for seg in ordered[1:]:
        prev = merged[-1]
        gap = seg.start - prev.end
        if seg.speaker == prev.speaker and -max_overlap <= gap <= max_gap:
            prev.end = max(prev.end, seg.end)
        else:
            merged.append(copy.copy(seg))
    return merged
def overlap_prefix(a: str, b: str, n: int = 12) -> bool:
    """Tell whether the leading part of one string occurs inside the other.

    Both strings are stripped and lower-cased first; the first ``n``
    characters of each are then searched for in the other string.

    Args:
        a: First string.
        b: Second string.
        n: Number of leading characters to compare.

    Returns:
        True when ``a[:n]`` occurs in ``b`` or ``b[:n]`` occurs in ``a``.
        False when either string is empty or whitespace-only, or when ``n``
        is not positive.
    """
    if not a or not b or n <= 0:
        return False
    a = a.strip().lower()
    b = b.strip().lower()
    # Re-check after stripping: an empty prefix is "contained" in every
    # string and would otherwise always report a match.
    if not a or not b:
        return False
    return a[:n] in b or b[:n] in a
class Processor:
    """Namespace for the end-to-end audio processing pipeline.

    ``process_audio`` converts the input to 16 kHz mono WAV, runs speaker
    diarization and ASR, aligns the recognized words to speakers, tags
    segments with an emotion and renders TXT/CSV reports. All methods are
    class or static methods; the class is never instantiated here.
    """

    # Role assigned to speakers that have no diarization-provided role.
    # Emotion prediction is only run for segments carrying this role.
    _KH_ROLE = "KH"
    # Display label used when a raw speaker id is missing from the speaker map.
    _DEFAULT_LABEL = "Speaker 1"

    # Icons are written as escapes so they survive re-encoding of this file.
    _FALLBACK_EMOTION_ICON = "\U0001F642"  # slightly smiling face
    _SPEAKER_ICONS = (
        "\U0001F535",  # blue circle
        "\U0001F7E2",  # green circle
        "\U0001F7E1",  # yellow circle
        "\U0001F7E0",  # orange circle
        "\U0001F534",  # red circle
        "\U0001F7E3",  # purple circle
    )
    _UNKNOWN_SPEAKER_ICON = "\u26AA"  # white circle

    @classmethod
    async def process_audio(
        cls,
        audio_path: Path,
        model_name: str = "PhoWhisper Lora Finetuned",
        language: str = "vi",
        merge_segments: bool = True,
    ) -> ProcessingResult:
        """Run the full pipeline on one audio file.

        Args:
            audio_path: Path of the audio file to process.
            model_name: ASR model name forwarded to ``TranscriptionService``.
            language: Language code forwarded to ``TranscriptionService``.
            merge_segments: When True, consecutive diarization segments of
                the same speaker are merged before alignment.

        Returns:
            A ``ProcessingResult`` with the transcript segments, speaker and
            role information, TXT/CSV renderings and the emotion timeline.

        Raises:
            ValueError: If the loaded audio contains no samples.
        """
        import asyncio

        t0 = time.time()
        EmotionService.preload_model()

        # 1: Convert to WAV. ffmpeg blocks, so it runs in the default executor.
        logger.info("Step 1: Converting audio to WAV 16kHz...")
        wav_path = await asyncio.get_running_loop().run_in_executor(
            None, convert_audio_to_wav, audio_path
        )

        # 2: Load audio
        y, sr = librosa.load(wav_path, sr=16000, mono=True)
        if y.size == 0:
            raise ValueError("Empty audio")
        duration = len(y) / sr

        # 3: Diarization
        logger.info("Step 3: Running diarization...")
        diarization: DiarizationResult = await DiarizationService.diarize_async(
            wav_path
        )
        # sorted() builds a new list, leaving the diarization result untouched.
        diarization_segments = sorted(
            diarization.segments or [], key=lambda seg: seg.start
        )
        if not diarization_segments:
            # No speaker turns found: treat the whole file as one speaker.
            diarization_segments = [SpeakerSegment(0.0, duration, "SPEAKER_0")]
        if merge_segments:
            logger.info("Step 4: Merging consecutive segments...")
            diarization_segments = merge_consecutive_segments(diarization_segments)

        # 4: Normalize speakers to stable display labels ("Speaker 1", ...)
        raw_speakers = sorted({seg.speaker for seg in diarization_segments})
        speaker_map = {
            spk: f"Speaker {i + 1}" for i, spk in enumerate(raw_speakers)
        }
        speakers = list(speaker_map.values())
        raw_roles = diarization.roles or {}
        roles = {
            label: raw_roles.get(raw_spk, cls._KH_ROLE)
            for raw_spk, label in speaker_map.items()
        }
        logger.info(f"roles(mapped) = {roles}")

        # 5: Transcribe the whole recording
        logger.info("Step 7: Running ASR with external VAD batch...")
        asr_result = await TranscriptionService.transcribe_with_words_async(
            audio_array=y,
            model_name=model_name,
            language=language,
            vad_options=False,
        )
        text, raw_words = normalize_asr_result(asr_result)

        if not raw_words:
            processed_segments = [
                TranscriptSegment(
                    start=0.0,
                    end=duration,
                    speaker=speakers[0],
                    role=roles[speakers[0]],
                    text="(No speech detected)",
                )
            ]
        else:
            # Convert to WordTimestamp, filling in speakers the ASR did not
            # provide from the diarization segments.
            word_objs: List[WordTimestamp] = []
            for w in raw_words:
                spk = w.get("speaker")
                if spk is None:
                    spk = guess_speaker_by_overlap(
                        w["start"], w["end"], diarization_segments
                    )
                word_objs.append(
                    WordTimestamp(
                        word=w["word"],
                        start=w["start"],
                        end=w["end"],
                        speaker=spk,
                        confidence=w.get("confidence", 1.0),
                    )
                )
            word_objs.sort(key=lambda word: word.start)

            # Align words to diarization segments
            aligned_segments = AlignmentService.align_precision(
                word_objs, diarization_segments
            )

            processed_segments = []
            if not aligned_segments:
                # Alignment produced nothing: emit the full transcript as one
                # segment owned by the speaker with the most words.
                votes = [word.speaker for word in word_objs if word.speaker]
                if votes:
                    raw_spk = Counter(votes).most_common(1)[0][0]
                else:
                    raw_spk = diarization_segments[0].speaker
                label = speaker_map.get(raw_spk, cls._DEFAULT_LABEL)
                processed_segments.append(
                    TranscriptSegment(
                        start=0.0,
                        end=duration,
                        speaker=label,
                        role=roles.get(label, cls._KH_ROLE),
                        text=text,
                    )
                )
            else:
                for seg in aligned_segments:
                    label = speaker_map.get(seg.speaker, cls._DEFAULT_LABEL)
                    processed_segments.append(
                        TranscriptSegment(
                            start=seg.start,
                            end=seg.end,
                            speaker=label,
                            role=roles.get(label, cls._KH_ROLE),
                            text=seg.text,
                        )
                    )
        processed_segments.sort(key=lambda seg: seg.start)

        # 6: Predict emotion per segment, then derive timeline and changes
        logger.info("Step 8: Predicting emo per segment ")
        processed_segments = cls._predict_emotion_segments(processed_segments, y, sr)
        emotion_timeline = cls.build_emotion_timeline(processed_segments)
        emotion_changes = cls.detect_emotion_changes(emotion_timeline)

        processing_time = time.time() - t0
        txt_content = cls._generate_txt(
            processed_segments, len(speakers), processing_time, duration, roles
        )
        csv_content = cls._generate_csv(processed_segments)
        return ProcessingResult(
            segments=processed_segments,
            speaker_count=len(speakers),
            duration=duration,
            processing_time=processing_time,
            speakers=speakers,
            roles=roles,
            txt_content=txt_content,
            csv_content=csv_content,
            emotion_timeline=emotion_timeline,
            emotion_changes=emotion_changes,
        )

    @staticmethod
    def _predict_emotion_segments(
        segments: List[TranscriptSegment], audio: np.ndarray, sr: int
    ):
        """Fill ``emotion`` and ``icon`` on each segment, in place.

        Args:
            segments: Segments to annotate.
            audio: Audio samples forwarded to ``EmotionService``.
            sr: Sample rate of ``audio``.

        Returns:
            The same ``segments`` list.
        """
        for seg in segments:
            # Only predict emotion for "KH" segments; clear it for other roles.
            if seg.role != Processor._KH_ROLE:
                seg.emotion = None
                seg.icon = None
                continue
            emotion = EmotionService.predict_segment(audio, sr, seg.start, seg.end)
            seg.emotion = emotion
            seg.icon = EmotionService.meta.get(emotion, {}).get(
                "emoji", Processor._FALLBACK_EMOTION_ICON
            )
        return segments

    @staticmethod
    def build_emotion_timeline(segments):
        """Build the emotion timeline from annotated segments.

        Only "KH" segments that carry both an emotion and an icon contribute;
        each contributes one ``EmotionPoint`` at the segment's start time.

        Args:
            segments: Segments exposing ``role``, ``emotion``, ``icon`` and
                ``start``.

        Returns:
            A list of ``EmotionPoint`` in the order of ``segments``.
        """
        timeline = []
        for seg in segments:
            if seg.role != Processor._KH_ROLE or not seg.emotion or not seg.icon:
                continue
            icon = EmotionService.meta.get(seg.emotion, {}).get(
                "emoji", Processor._FALLBACK_EMOTION_ICON
            )
            timeline.append(
                EmotionPoint(time=seg.start, emotion=seg.emotion, icon=icon)
            )
        return timeline

    @staticmethod
    def detect_emotion_changes(timeline):
        """List the points of the timeline where the emotion changes.

        Args:
            timeline: Sequence of points exposing ``time`` and ``emotion``.

        Returns:
            A list of ``EmotionChange``, one per pair of consecutive points
            whose emotions differ, timed at the later point.
        """
        changes = []
        prev = None
        for point in timeline:
            if prev is not None and prev.emotion != point.emotion:
                icon_from = EmotionService.meta.get(prev.emotion, {}).get(
                    "emoji", Processor._FALLBACK_EMOTION_ICON
                )
                icon_to = EmotionService.meta.get(point.emotion, {}).get(
                    "emoji", Processor._FALLBACK_EMOTION_ICON
                )
                changes.append(
                    EmotionChange(
                        time=point.time,
                        emotion_from=prev.emotion,
                        emotion_to=point.emotion,
                        icon_from=icon_from,
                        icon_to=icon_to,
                    )
                )
            prev = point
        return changes

    @classmethod
    def _generate_txt(
        cls,
        segments: List[TranscriptSegment],
        speaker_count: int,
        processing_time: float,
        duration: float,
        roles: Dict[str, str],
    ) -> str:
        """Render the transcript as annotated plain text.

        Args:
            segments: Transcript segments (rendered in order of ``start``).
            speaker_count: Number of speakers, printed in the header.
            processing_time: Processing time in seconds, printed in the header.
            duration: Audio duration in seconds, printed in the header.
            roles: Speaker label -> role mapping, printed in the header.

        Returns:
            The report: a ``#``-prefixed header, a blank line, then one line
            per segment, joined with newlines.
        """
        segments = sorted(segments, key=lambda seg: seg.start)
        # Unique speakers in order of first appearance (dict keeps insertion
        # order), so each speaker gets a stable icon.
        speakers = list(dict.fromkeys(seg.speaker for seg in segments if seg.speaker))
        speaker_icons = {
            spk: cls._SPEAKER_ICONS[i % len(cls._SPEAKER_ICONS)]
            for i, spk in enumerate(speakers)
        }

        lines = [
            "# Transcription Result",
            f"# Duration: {format_timestamp(duration)}",
            f"# Speakers: {speaker_count}",
            f"# Roles: {roles}",
            f"# Processing time: {processing_time:.1f}s",
            "",
        ]
        for seg in segments:
            ts = f"[{format_timestamp(seg.start)} \u2192 {format_timestamp(seg.end)}]"
            role = seg.role or "UNKNOWN"
            speaker_icon = speaker_icons.get(seg.speaker, cls._UNKNOWN_SPEAKER_ICON)
            emotion = seg.emotion or ""
            emotion_icon = (
                EmotionService.meta.get(emotion, {}).get("emoji", "") if emotion else ""
            )
            lines.append(
                f"{ts} {speaker_icon} [{seg.speaker}|{role}] {seg.text} {emotion_icon} {emotion}"
            )
        return "\n".join(lines)

    @classmethod
    def _generate_csv(cls, segments: List[TranscriptSegment]) -> str:
        """Render the transcript as CSV.

        Args:
            segments: Transcript segments, written in the given order.

        Returns:
            CSV text with the header ``start,end,speaker,text,emotion,icon``
            and one row per segment; times are rounded to 3 decimals.
        """
        import csv
        from io import StringIO

        output = StringIO()
        writer = csv.writer(output)
        writer.writerow(["start", "end", "speaker", "text", "emotion", "icon"])
        for seg in segments:
            emotion = seg.emotion or ""
            icon = (
                EmotionService.meta.get(emotion, {}).get("emoji", "") if emotion else ""
            )
            writer.writerow(
                [
                    round(seg.start, 3),
                    round(seg.end, 3),
                    seg.speaker,
                    seg.text,
                    emotion,
                    icon,
                ]
            )
        return output.getvalue()