CleanSong committed on
Commit
6ab727a
·
verified ·
1 Parent(s): be47ae1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -79
app.py CHANGED
@@ -41,18 +41,17 @@ print("✅ Models ready!")
41
 
42
  # === TRANSCRIBE FUNCTION (HYBRID WORD-LEVEL) ===
43
  def transcribe(file_path):
44
- # --- Ensure proper audio format (mono, 16k) ---
45
  wav, sr = torchaudio.load(file_path)
46
  target_sr = 16000
47
  if sr != target_sr:
48
  wav = torchaudio.functional.resample(wav, sr, target_sr)
49
  if wav.shape[0] > 1:
50
- wav = wav.mean(dim=0, keepdim=True) # mono
51
  fixed_path = "input_fixed.wav"
52
  torchaudio.save(fixed_path, wav, target_sr)
53
 
54
- # --- FAST PASS (cheap) ---
55
- print("⚡ Running fast pass to detect candidate explicit words…")
56
  fast_segments, fast_info = fast_model.transcribe(
57
  fixed_path,
58
  beam_size=1,
@@ -61,101 +60,86 @@ def transcribe(file_path):
61
  )
62
  sample_rate = getattr(fast_info, "sample_rate", target_sr)
63
 
64
- # Build initial transcript
65
  transcript = []
66
  for seg in fast_segments:
67
  if hasattr(seg, "words") and seg.words:
68
  for w in seg.words:
69
  word_text = w.word.strip()
70
- start = float(w.start)
71
- end = float(w.end)
72
  transcript.append({
73
  "word": word_text,
74
- "start": start,
75
- "end": end,
76
- "explicit": word_text.lower() in BAD_WORDS
 
77
  })
78
  else:
79
  transcript.append({
80
  "text": seg.text,
81
  "start": float(seg.start),
82
  "end": float(seg.end),
83
- "explicit": False
 
84
  })
85
 
86
- # --- SECOND PASS: large model on explicit words only ---
87
  flagged_words = [t for t in transcript if t.get("explicit")]
88
- if flagged_words:
89
- print(f"🔎 Fast pass flagged {len(flagged_words)} explicit words — refining with large model…")
90
- refined_entries = []
91
-
92
- for idx, w in enumerate(flagged_words):
93
- s, e = w["start"], w["end"]
94
- print(f"⏱️ Refining word {idx+1}/{len(flagged_words)}: {s:.2f}s -> {e:.2f}s")
95
-
96
- start_sample = int(max(0, s * sample_rate))
97
- end_sample = int(min(wav.shape[-1], e * sample_rate))
98
- num_frames = max(0, end_sample - start_sample)
99
- if num_frames == 0:
100
- continue
101
- chunk = wav[:, start_sample:end_sample]
102
-
103
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
104
- temp_path = tmp.name
105
- torchaudio.save(temp_path, chunk, sample_rate)
106
-
107
- segs, _ = large_model.transcribe(
108
- temp_path,
109
- beam_size=5,
110
- word_timestamps=True,
111
- vad_filter=True
112
- )
113
-
114
- for seg in segs:
115
- if hasattr(seg, "words") and seg.words:
116
- for word_obj in seg.words:
117
- refined_entries.append({
118
- "word": word_obj.word.strip(),
119
- "start": float(word_obj.start) + s,
120
- "end": float(word_obj.end) + s,
121
- "explicit": word_obj.word.strip().lower() in BAD_WORDS
122
- })
123
- else:
124
  refined_entries.append({
125
- "text": seg.text,
126
- "start": float(seg.start) + s,
127
- "end": float(seg.end) + s,
128
- "explicit": False
 
129
  })
 
 
 
 
 
 
 
 
130
 
131
- try:
132
- os.remove(temp_path)
133
- except Exception:
134
- pass
 
 
 
 
 
 
 
 
 
135
 
136
- # Merge refined words back into transcript
137
- final_transcript = []
138
- for t in transcript:
139
- if t.get("explicit") and refined_entries:
140
- final_transcript.append(refined_entries.pop(0))
141
- else:
142
- final_transcript.append(t)
143
- transcript = final_transcript
144
- else:
145
- print("✅ No flagged words — skipping large-model refinement.")
146
-
147
- # --- fallback if transcript empty ---
148
- if not transcript:
149
- transcript = [{
150
- "text": seg.text,
151
- "start": float(seg.start),
152
- "end": float(seg.end),
153
- "explicit": False
154
- } for seg in fast_segments]
155
-
156
- print(f"✅ Final transcript contains {len(transcript)} entries "
157
- f"({sum(1 for w in transcript if w.get('explicit'))} explicit). {transcript[:200]}")
158
- return transcript
159
 
160
  # === GRADIO INTERFACE ===
161
  iface = gr.Interface(
 
41
 
42
  # === TRANSCRIBE FUNCTION (HYBRID WORD-LEVEL) ===
43
  def transcribe(file_path):
44
+ # --- Ensure proper audio format ---
45
  wav, sr = torchaudio.load(file_path)
46
  target_sr = 16000
47
  if sr != target_sr:
48
  wav = torchaudio.functional.resample(wav, sr, target_sr)
49
  if wav.shape[0] > 1:
50
+ wav = wav.mean(dim=0, keepdim=True)
51
  fixed_path = "input_fixed.wav"
52
  torchaudio.save(fixed_path, wav, target_sr)
53
 
54
+ # --- FAST PASS ---
 
55
  fast_segments, fast_info = fast_model.transcribe(
56
  fixed_path,
57
  beam_size=1,
 
60
  )
61
  sample_rate = getattr(fast_info, "sample_rate", target_sr)
62
 
63
+ # Initial transcript with explicit flags
64
  transcript = []
65
  for seg in fast_segments:
66
  if hasattr(seg, "words") and seg.words:
67
  for w in seg.words:
68
  word_text = w.word.strip()
69
+ is_explicit = word_text.lower() in BAD_WORDS
 
70
  transcript.append({
71
  "word": word_text,
72
+ "start": float(w.start),
73
+ "end": float(w.end),
74
+ "explicit": is_explicit, # 🔥 keep fast-pass explicit flag
75
+ "explicit_fast": is_explicit # permanent record of fast-pass
76
  })
77
  else:
78
  transcript.append({
79
  "text": seg.text,
80
  "start": float(seg.start),
81
  "end": float(seg.end),
82
+ "explicit": False,
83
+ "explicit_fast": False
84
  })
85
 
86
+ # --- SECOND PASS: refine explicit words only ---
87
  flagged_words = [t for t in transcript if t.get("explicit")]
88
+ refined_entries = []
89
+
90
+ for idx, w in enumerate(flagged_words):
91
+ s, e = w["start"], w["end"]
92
+ start_sample = int(max(0, s * sample_rate))
93
+ end_sample = int(min(wav.shape[-1], e * sample_rate))
94
+ chunk = wav[:, start_sample:end_sample]
95
+
96
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
97
+ temp_path = tmp.name
98
+ torchaudio.save(temp_path, chunk, sample_rate)
99
+
100
+ segs, _ = large_model.transcribe(
101
+ temp_path,
102
+ beam_size=5,
103
+ word_timestamps=True,
104
+ vad_filter=True
105
+ )
106
+
107
+ for seg in segs:
108
+ if hasattr(seg, "words") and seg.words:
109
+ for word_obj in seg.words:
110
+ # 🔥 Keep explicit from fast-pass instead of trusting large model
111
+ orig_explicit = w.get("explicit_fast", False)
 
 
 
 
 
 
 
 
 
 
 
 
112
  refined_entries.append({
113
+ "word": word_obj.word.strip(),
114
+ "start": float(word_obj.start) + s,
115
+ "end": float(word_obj.end) + s,
116
+ "explicit": orig_explicit, # preserve explicit
117
+ "explicit_fast": orig_explicit
118
  })
119
+ else:
120
+ refined_entries.append({
121
+ "text": seg.text,
122
+ "start": float(seg.start) + s,
123
+ "end": float(seg.end) + s,
124
+ "explicit": w.get("explicit_fast", False),
125
+ "explicit_fast": w.get("explicit_fast", False)
126
+ })
127
 
128
+ try:
129
+ os.remove(temp_path)
130
+ except Exception:
131
+ pass
132
+
133
+ # Merge refined words back, keeping fast-pass explicit
134
+ final_transcript = []
135
+ refined_iter = iter(refined_entries)
136
+ for t in transcript:
137
+ if t.get("explicit"):
138
+ final_transcript.append(next(refined_iter))
139
+ else:
140
+ final_transcript.append(t)
141
 
142
+ return final_transcript
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  # === GRADIO INTERFACE ===
145
  iface = gr.Interface(