mimoha committed on
Commit
61b8eb3
·
verified ·
1 Parent(s): 3a2a9b2

Upload whisper_asr.py

Browse files
Files changed (1) hide show
  1. whisper_asr.py +121 -0
whisper_asr.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # whisper_asr.py
2
+ import os, re, json, math, tempfile, traceback
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import soundfile as sf
7
+ import textdistance
8
+
9
# These settings are defined once in this module and are meant to be imported
# by the later post-processing module, so the definitions are not duplicated.
FORCE_WHISPER_NAME = "large-v3"
FORCE_COMPUTE_TYPE = "int8"
FORCE_USE_MARBERT = True

# Fixed transcription options, kept constant to reduce run-to-run differences.
ASR_OPTS = {
    "word_timestamps": True,
    "vad_filter": True,
    "vad_parameters": {"min_silence_duration_ms": 200},
    "beam_size": 5,
    "best_of": 5,
    "temperature": 0.0,
}

# =========================
# Device
# =========================
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[INIT] DEVICE={DEVICE}", flush=True)

# =========================
# Lazy models (Whisper only)
# =========================
# Module-level cache for the Whisper model; filled on first load.
_WHISPER = None
35
+
36
def load_whisper_model(
    whisper_name=FORCE_WHISPER_NAME,
    whisper_compute=FORCE_COMPUTE_TYPE,
):
    """Return the shared Whisper model, creating it on the first call.

    The model is cached in the module-level ``_WHISPER`` variable. Once it
    exists, later calls return the cached instance and ignore the arguments.

    Args:
        whisper_name: Model name handed to ``WhisperModel``.
        whisper_compute: Compute type handed to ``WhisperModel``; it is passed
            through unchanged whether the device is CUDA or CPU.

    Returns:
        The cached ``faster_whisper.WhisperModel`` instance.
    """
    global _WHISPER
    # Imported lazily so that importing this module does not require
    # faster_whisper until a model is actually requested.
    from faster_whisper import WhisperModel

    if _WHISPER is not None:
        return _WHISPER

    target_device = "cuda" if DEVICE == "cuda" else "cpu"
    _WHISPER = WhisperModel(
        whisper_name,
        device=target_device,
        compute_type=whisper_compute,
    )
    print(f"[LOAD] Whisper: {whisper_name} (compute={whisper_compute})", flush=True)
    return _WHISPER
49
+
50
+ # =========================
51
+ # Faster-Whisper helpers
52
+ # =========================
53
def normalize_ar_orth(text: str) -> str:
    """Normalize Arabic text for alignment.

    Applies three substitutions in order: remove the listed diacritic marks
    and the tatweel character, replace the listed punctuation characters with
    a space, and collapse every whitespace run into a single space. The result
    is stripped of leading and trailing whitespace.

    Args:
        text: Input string.

    Returns:
        The normalized string.
    """
    substitutions = (
        (r"[ًٌٍَُِّْـ]", ""),
        (r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " "),
        (r"\s+", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
59
+
60
def clean_ar_token(t: str) -> str:
    """Clean a single token before alignment.

    Strips surrounding whitespace, removes leading and trailing runs of
    characters that are neither word characters nor in U+0600-U+06FF, then
    applies ``normalize_ar_orth`` to what remains.

    Args:
        t: Raw token text.

    Returns:
        The cleaned token (possibly an empty string).
    """
    edge_junk = r'^[^\w\u0600-\u06FF]+|[^\w\u0600-\u06FF]+$'
    trimmed = re.sub(edge_junk, '', t.strip())
    return normalize_ar_orth(trimmed)
65
+
66
def extract_word_conf_table(segments):
    """Flatten per-word timing and probability data into a DataFrame.

    Args:
        segments: Iterable of objects exposing ``start``, ``end`` and
            ``words``. Each word exposes ``start``, ``end``, ``word`` and
            ``probability``. A segment whose ``words`` is falsy contributes
            no rows.

    Returns:
        A ``pandas.DataFrame`` with one row per word and the columns
        ``seg_start``, ``seg_end``, ``word_start``, ``word_end``, ``word``
        (passed through ``clean_ar_token``) and ``prob``. With no words the
        frame is empty and has no columns.
    """
    records = [
        {
            "seg_start": float(segment.start),
            "seg_end": float(segment.end),
            "word_start": float(word.start),
            "word_end": float(word.end),
            "word": clean_ar_token(word.word),
            "prob": float(word.probability),
        }
        for segment in segments
        for word in (segment.words or [])
    ]
    return pd.DataFrame(records)
79
+
80
def build_asr_token_conf(df_words: pd.DataFrame, hyp_tokens: list):
    """Pair hypothesis token positions with ASR word confidence and duration.

    Rows of ``df_words`` are matched to ``hyp_tokens`` purely by position:
    extra rows are dropped, and missing positions are filled with ``None``.

    Args:
        df_words: Word table read through its ``prob``, ``word_start`` and
            ``word_end`` columns.
        hyp_tokens: Hypothesis tokens; only the length is used.

    Returns:
        A tuple ``(asr_token_conf, low_t, high_t)``. ``asr_token_conf`` maps
        each token index to ``{"prob": ..., "duration_ms": ...}``. ``low_t``
        and ``high_t`` are the 0.15 and 0.70 quantiles of the available
        probabilities, or ``0.5`` and ``0.85`` when there are none.
    """
    # One (probability, duration in ms) pair per word row.
    pairs = [
        (row["prob"], (row["word_end"] - row["word_start"]) * 1000.0)
        for _, row in df_words.iterrows()
    ]

    # Fit the list to the token count: truncate first, then pad with blanks.
    n_tokens = len(hyp_tokens)
    pairs = pairs[:n_tokens]
    pairs += [(None, None)] * (n_tokens - len(pairs))

    known_probs = np.array([prob for prob, _ in pairs if prob is not None])
    if known_probs.size:
        low_t = float(np.quantile(known_probs, 0.15))
        high_t = float(np.quantile(known_probs, 0.70))
    else:
        low_t, high_t = 0.5, 0.85

    asr_token_conf = {
        idx: {"prob": prob, "duration_ms": dur}
        for idx, (prob, dur) in enumerate(pairs)
    }
    return asr_token_conf, low_t, high_t
106
+
107
+ # =========================
108
+ # Audio helper
109
+ # =========================
110
def ensure_audio_path(audio):
    """Return a filesystem path for the given audio input.

    Args:
        audio: Either a string path to an existing file, or a 2-tuple
            ``(data, sr)`` where ``data`` is a ``numpy.ndarray`` of samples
            and ``sr`` is the sample rate.
            NOTE(review): some UI toolkits hand over ``(sr, data)`` instead;
            that order is rejected here. Confirm against the caller.

    Returns:
        The input path unchanged, or the path of a newly written ``.wav``
        temporary file. The temporary file is created with ``delete=False``,
        so it is not removed automatically.

    Raises:
        FileNotFoundError: If ``audio`` is a string that is not an existing path.
        ValueError: If ``audio`` is neither of the supported forms.
    """
    if isinstance(audio, str):
        if not os.path.exists(audio):
            raise FileNotFoundError(f"Audio path not found: {audio}")
        return audio

    if isinstance(audio, tuple) and len(audio) == 2:
        data, sr = audio
        if isinstance(data, np.ndarray):
            # Reserve a unique name and close the handle before writing, so no
            # file descriptor is leaked and the file is not held open while
            # soundfile reopens it by name (which fails on Windows).
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                wav_path = tmp.name
            sf.write(wav_path, data, sr)
            return wav_path

    raise ValueError("Unsupported audio input format")