CleanSong committed on
Commit
7a79027
·
verified ·
1 Parent(s): 33b51a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -91
app.py CHANGED
@@ -29,7 +29,6 @@ def get_bad_words():
29
  except Exception as e:
30
  print(f"⚠️ Failed to fetch list: {e}")
31
 
32
- # fallback local list
33
  fallback = {"fuck", "shit", "bitch", "ass", "nigga", "nigger", "pussy", "cunt"}
34
  print(f"⚠️ Using fallback list ({len(fallback)} words).")
35
  return fallback
@@ -43,30 +42,7 @@ print(f"🚀 Loading LARGE Whisper model: {MODEL_NAME} ({COMPUTE_TYPE}) on {DEVI
43
  large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
44
  print("✅ Models ready!")
45
 
46
- # === HELPERS ===
47
- def merge_intervals(intervals, padding=0.15):
48
- """Merge overlapping intervals; also expand each interval by padding seconds."""
49
- if not intervals:
50
- return []
51
- # apply padding
52
- intervals = [(max(0, s - padding), e + padding) for s, e in intervals]
53
- intervals.sort(key=lambda x: x[0])
54
- merged = []
55
- cur_s, cur_e = intervals[0]
56
- for s, e in intervals[1:]:
57
- if s <= cur_e:
58
- cur_e = max(cur_e, e)
59
- else:
60
- merged.append((cur_s, cur_e))
61
- cur_s, cur_e = s, e
62
- merged.append((cur_s, cur_e))
63
- return merged
64
-
65
- def replace_range_in_list(lst, start_idx, end_idx, new_items):
66
- """Replace lst[start_idx:end_idx] with new_items (in-place)."""
67
- return lst[:start_idx] + new_items + lst[end_idx:]
68
-
69
- # === TRANSCRIBE FUNCTION (HYBRID) ===
70
  def transcribe(file_path):
71
  # --- Ensure proper audio format (mono, 16k) ---
72
  wav, sr = torchaudio.load(file_path)
@@ -78,8 +54,8 @@ def transcribe(file_path):
78
  fixed_path = "input_fixed.wav"
79
  torchaudio.save(fixed_path, wav, target_sr)
80
 
81
- # --- FAST PASS (cheap, detect possible explicit words) ---
82
- print("⚡ Running fast (cheap) pass to detect candidate explicit words…")
83
  fast_segments, fast_info = fast_model.transcribe(
84
  fixed_path,
85
  beam_size=1,
@@ -88,7 +64,7 @@ def transcribe(file_path):
88
  )
89
  sample_rate = getattr(fast_info, "sample_rate", target_sr)
90
 
91
- # Build initial transcript from fast pass
92
  transcript = []
93
  for seg in fast_segments:
94
  if hasattr(seg, "words") and seg.words:
@@ -103,7 +79,6 @@ def transcribe(file_path):
103
  "explicit": word_text.lower() in BAD_WORDS
104
  })
105
  else:
106
- # fallback: segment-level entry
107
  transcript.append({
108
  "text": seg.text,
109
  "start": float(seg.start),
@@ -111,40 +86,34 @@ def transcribe(file_path):
111
  "explicit": False
112
  })
113
 
114
- # --- Determine flagged intervals to re-run with large model ---
115
- flagged_intervals = [(w["start"], w["end"]) for w in transcript if w.get("explicit")]
116
- merged_intervals = merge_intervals(flagged_intervals, padding=0.15)
117
- print(f"🔎 Fast pass flagged {len(flagged_intervals)} words -> {len(merged_intervals)} merged intervals")
118
-
119
- # --- SECOND PASS (large model) on flagged words only ---
120
- if flagged_intervals:
121
  refined_entries = []
122
- for idx, w in enumerate([t for t in transcript if t.get("explicit")]):
 
123
  s, e = w["start"], w["end"]
124
- print(f"⏱️ Refining explicit word {idx+1}/{len(flagged_intervals)}: {s:.2f}s -> {e:.2f}s")
125
-
126
- # extract the single word chunk
127
  start_sample = int(max(0, s * sample_rate))
128
  end_sample = int(min(wav.shape[-1], e * sample_rate))
129
  num_frames = max(0, end_sample - start_sample)
130
  if num_frames == 0:
131
  continue
132
  chunk = wav[:, start_sample:end_sample]
133
-
134
- # write temp file
135
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
136
  temp_path = tmp.name
137
  torchaudio.save(temp_path, chunk, sample_rate)
138
-
139
- # transcribe chunk with large model
140
  segs, _ = large_model.transcribe(
141
  temp_path,
142
  beam_size=5,
143
  word_timestamps=True,
144
  vad_filter=True
145
  )
146
-
147
- # adjust chunk-relative timestamps to original
148
  for seg in segs:
149
  if hasattr(seg, "words") and seg.words:
150
  for word_obj in seg.words:
@@ -161,58 +130,24 @@ def transcribe(file_path):
161
  "end": float(seg.end) + s,
162
  "explicit": False
163
  })
164
-
165
  try:
166
  os.remove(temp_path)
167
  except Exception:
168
  pass
169
 
170
- # Merge refined words back into transcript
171
- final_transcript = []
172
- i = 0
173
- for t in transcript:
174
- if t.get("explicit"):
175
- # replace flagged word with refined version
176
- refined_word = refined_entries.pop(0)
177
- final_transcript.append(refined_word)
178
- else:
179
- final_transcript.append(t)
180
- transcript = final_transcript
181
-
182
-
183
- # --- Merge refined entries into the original fast transcript ---
184
- # For each merged interval, replace overlapping fast-pass items with refined items for that interval.
185
  final_transcript = []
186
- i = 0
187
- n = len(transcript)
188
- for interval in merged_intervals:
189
- int_s, int_e = interval
190
- # append all fast-pass items before this interval
191
- while i < n and transcript[i]["end"] <= int_s:
192
- final_transcript.append(transcript[i])
193
- i += 1
194
- # skip fast-pass items that overlap the interval
195
- skip_start = i
196
- while i < n and transcript[i]["start"] < int_e:
197
- i += 1
198
- skip_end = i
199
- # collect refined entries for this interval (those whose times fall into interval)
200
- refined_for_interval = [r for r in refined_entries if not (r["end"] <= int_s or r["start"] >= int_e)]
201
- # sort refined entries
202
- refined_for_interval.sort(key=lambda x: x.get("start", 0))
203
- # append refined entries instead of the skipped fast-pass items
204
- final_transcript.extend(refined_for_interval)
205
- # append any remaining fast-pass items after all intervals
206
- while i < n:
207
- final_transcript.append(transcript[i])
208
- i += 1
209
-
210
- # If no merged_intervals matched anything, fall back to initial transcript
211
- transcript = final_transcript if final_transcript else transcript
212
  else:
213
- print("✅ No flagged intervals — skipping large-model refinement.")
214
 
215
- # --- final housekeeping: if transcript is empty, build segment-level fallback from fast pass segments ---
216
  if not transcript:
217
  transcript = [{
218
  "text": seg.text,
@@ -231,7 +166,7 @@ iface = gr.Interface(
231
  inputs=gr.Audio(type="filepath", label="Upload Vocals"),
232
  outputs=gr.JSON(label="Transcript with Explicit Flags"),
233
  title="CleanSong AI — Whisper Transcriber (Hybrid Fast→Accurate)",
234
- description="Two-pass transcription: fast model to detect explicit words, large model to refine only flagged intervals."
235
  )
236
 
237
  if __name__ == "__main__":
 
29
  except Exception as e:
30
  print(f"⚠️ Failed to fetch list: {e}")
31
 
 
32
  fallback = {"fuck", "shit", "bitch", "ass", "nigga", "nigger", "pussy", "cunt"}
33
  print(f"⚠️ Using fallback list ({len(fallback)} words).")
34
  return fallback
 
42
  large_model = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
43
  print("✅ Models ready!")
44
 
45
+ # === TRANSCRIBE FUNCTION (HYBRID WORD-LEVEL) ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def transcribe(file_path):
47
  # --- Ensure proper audio format (mono, 16k) ---
48
  wav, sr = torchaudio.load(file_path)
 
54
  fixed_path = "input_fixed.wav"
55
  torchaudio.save(fixed_path, wav, target_sr)
56
 
57
+ # --- FAST PASS (cheap) ---
58
+ print("⚡ Running fast pass to detect candidate explicit words…")
59
  fast_segments, fast_info = fast_model.transcribe(
60
  fixed_path,
61
  beam_size=1,
 
64
  )
65
  sample_rate = getattr(fast_info, "sample_rate", target_sr)
66
 
67
+ # Build initial transcript
68
  transcript = []
69
  for seg in fast_segments:
70
  if hasattr(seg, "words") and seg.words:
 
79
  "explicit": word_text.lower() in BAD_WORDS
80
  })
81
  else:
 
82
  transcript.append({
83
  "text": seg.text,
84
  "start": float(seg.start),
 
86
  "explicit": False
87
  })
88
 
89
+ # --- SECOND PASS: large model on explicit words only ---
90
+ flagged_words = [t for t in transcript if t.get("explicit")]
91
+ if flagged_words:
92
+ print(f"🔎 Fast pass flagged {len(flagged_words)} explicit words refining with large model…")
 
 
 
93
  refined_entries = []
94
+
95
+ for idx, w in enumerate(flagged_words):
96
  s, e = w["start"], w["end"]
97
+ print(f"⏱️ Refining word {idx+1}/{len(flagged_words)}: {s:.2f}s -> {e:.2f}s")
98
+
 
99
  start_sample = int(max(0, s * sample_rate))
100
  end_sample = int(min(wav.shape[-1], e * sample_rate))
101
  num_frames = max(0, end_sample - start_sample)
102
  if num_frames == 0:
103
  continue
104
  chunk = wav[:, start_sample:end_sample]
105
+
 
106
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
107
  temp_path = tmp.name
108
  torchaudio.save(temp_path, chunk, sample_rate)
109
+
 
110
  segs, _ = large_model.transcribe(
111
  temp_path,
112
  beam_size=5,
113
  word_timestamps=True,
114
  vad_filter=True
115
  )
116
+
 
117
  for seg in segs:
118
  if hasattr(seg, "words") and seg.words:
119
  for word_obj in seg.words:
 
130
  "end": float(seg.end) + s,
131
  "explicit": False
132
  })
133
+
134
  try:
135
  os.remove(temp_path)
136
  except Exception:
137
  pass
138
 
139
+ # Merge refined words back into transcript
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  final_transcript = []
141
+ for t in transcript:
142
+ if t.get("explicit") and refined_entries:
143
+ final_transcript.append(refined_entries.pop(0))
144
+ else:
145
+ final_transcript.append(t)
146
+ transcript = final_transcript
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  else:
148
+ print("✅ No flagged words — skipping large-model refinement.")
149
 
150
+ # --- fallback if transcript empty ---
151
  if not transcript:
152
  transcript = [{
153
  "text": seg.text,
 
166
  inputs=gr.Audio(type="filepath", label="Upload Vocals"),
167
  outputs=gr.JSON(label="Transcript with Explicit Flags"),
168
  title="CleanSong AI — Whisper Transcriber (Hybrid Fast→Accurate)",
169
+ description="Two-pass transcription: fast model to detect explicit words, large model to refine only flagged words."
170
  )
171
 
172
  if __name__ == "__main__":