Aynursusuz committed on
Commit
226ea00
·
verified ·
1 Parent(s): 7b2f634

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +296 -0
  2. packages.txt +1 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from transformers import ASTForAudioClassification, AutoFeatureExtractor
6
+ from pydub import AudioSegment
7
+ import tempfile
8
+ import logging
9
+ from datetime import datetime
10
+ from typing import Tuple, List, Optional
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class MusicRemover:
17
+
18
+ def __init__(self, model_name: str = "Vyvo-Research/AST-Music-Classifier-1K"):
19
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ logger.info(f"Initializing on {self.device}")
21
+
22
+ self.model = ASTForAudioClassification.from_pretrained(model_name).to(self.device)
23
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
24
+ self.model.eval()
25
+
26
+ if self.device.type == "cuda":
27
+ self.model = self.model.half()
28
+ torch.backends.cudnn.benchmark = True
29
+
30
+ logger.info("Model loaded successfully")
31
+
32
+ def load_audio(self, file_path: str):
33
+ audio = AudioSegment.from_file(file_path)
34
+ audio = audio.set_channels(1)
35
+
36
+ sample_rate = self.feature_extractor.sampling_rate
37
+ audio = audio.set_frame_rate(sample_rate)
38
+
39
+ samples = np.array(audio.get_array_of_samples()).astype(np.float32)
40
+ samples = samples / np.iinfo(audio.array_type).max
41
+
42
+ return samples, sample_rate, audio
43
+
44
+ @torch.no_grad()
45
+ def detect_music_segments(self, audio_array, sample_rate, threshold, window_size, hop_size):
46
+ window_samples = int(window_size * sample_rate)
47
+ hop_samples = int(hop_size * sample_rate)
48
+
49
+ music_segments = []
50
+ total_samples = len(audio_array)
51
+ total_duration = total_samples / sample_rate
52
+
53
+ logger.info(f"Audio: {total_duration:.1f}s, Window: {window_size}s, Hop: {hop_size}s")
54
+ logger.info(f"Total samples: {total_samples}, Window samples: {window_samples}, Hop samples: {hop_samples}")
55
+
56
+ segment_count = 0
57
+ last_was_music = False
58
+
59
+ for start in range(0, total_samples, hop_samples):
60
+ end = min(start + window_samples, total_samples)
61
+ segment = audio_array[start:end]
62
+ segment_duration = len(segment) / sample_rate
63
+
64
+ # Γ‡ok kΔ±sa segmentleri atla (1 saniyeden az)
65
+ if len(segment) < sample_rate:
66
+ logger.info(f"Skipping final segment (too short): {segment_duration:.2f}s")
67
+ continue
68
+
69
+ segment_count += 1
70
+ start_sec = start / sample_rate
71
+ end_sec = end / sample_rate
72
+
73
+ # KΔ±sa segmentleri padding ile doldur
74
+ needs_padding = len(segment) < window_samples
75
+ if needs_padding:
76
+ segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')
77
+ logger.info(f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s (padded)")
78
+ else:
79
+ logger.info(f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s")
80
+
81
+ inputs = self.feature_extractor(
82
+ segment,
83
+ sampling_rate=sample_rate,
84
+ return_tensors="pt",
85
+ padding="max_length",
86
+ truncation=True,
87
+ max_length=1024
88
+ )
89
+
90
+ if self.device.type == "cuda":
91
+ inputs = {k: v.to(self.device).half() for k, v in inputs.items()}
92
+ else:
93
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
94
+
95
+ outputs = self.model(**inputs)
96
+ probs = torch.softmax(outputs.logits, dim=-1)
97
+
98
+ # Label'larΔ± al
99
+ labels = self.model.config.id2label
100
+
101
+ # En yΓΌksek skorlu label'Δ± bul (argmax)
102
+ pred_idx = torch.argmax(probs[0]).item()
103
+ pred_label = labels.get(pred_idx, f'idx{pred_idx}')
104
+ pred_score = probs[0][pred_idx].item()
105
+
106
+ logger.info(f" -> Prediction: {pred_label} ({pred_score:.2%})")
107
+
108
+ # Eğer prediction "music" ise ve confidence yeterli ise müzik olarak işaretle
109
+ is_music = 'music' in pred_label.lower()
110
+
111
+ # Belirsiz sonuΓ§ kontrolΓΌ (40-60% arasΔ±)
112
+ is_uncertain = 0.40 <= pred_score <= 0.60
113
+
114
+ if is_uncertain and needs_padding:
115
+ # KΔ±sa segment + belirsiz sonuΓ§ = ΓΆnceki sonucu kullan
116
+ if last_was_music:
117
+ start_ms = int(start_sec * 1000)
118
+ end_ms = int(end_sec * 1000)
119
+ music_segments.append((start_ms, end_ms, pred_score))
120
+ logger.info(f" -> MUSIC (uncertain {pred_score:.0%}, using previous)")
121
+ else:
122
+ logger.info(f" -> SPEECH (uncertain {pred_score:.0%}, using previous)")
123
+ elif is_music and pred_score >= threshold:
124
+ start_ms = int(start_sec * 1000)
125
+ end_ms = int(end_sec * 1000)
126
+ music_segments.append((start_ms, end_ms, pred_score))
127
+ last_was_music = True
128
+ logger.info(f" -> MUSIC DETECTED!")
129
+ else:
130
+ last_was_music = False
131
+ if is_music:
132
+ logger.info(f" -> Low confidence music ({pred_score:.1%} < {threshold:.0%}), treating as speech")
133
+
134
+ logger.info(f"Processed {segment_count} segments, found {len(music_segments)} music segments")
135
+ return music_segments
136
+
137
+ def merge_overlapping_segments(self, segments):
138
+ if not segments:
139
+ return []
140
+
141
+ segments = sorted(segments, key=lambda x: x[0])
142
+ merged = [segments[0]]
143
+
144
+ for current in segments[1:]:
145
+ last = merged[-1]
146
+
147
+ if current[0] <= last[1]:
148
+ merged[-1] = (
149
+ last[0],
150
+ max(last[1], current[1]),
151
+ max(last[2], current[2])
152
+ )
153
+ else:
154
+ merged.append(current)
155
+
156
+ return merged
157
+
158
+ def remove_music(self, audio, music_segments):
159
+ if not music_segments:
160
+ return audio, [(0, len(audio)/1000)]
161
+
162
+ clean_segments = []
163
+ kept_ranges = []
164
+ last_end = 0
165
+
166
+ for start_ms, end_ms, _ in music_segments:
167
+ if start_ms > last_end:
168
+ clean_segments.append(audio[last_end:start_ms])
169
+ kept_ranges.append((last_end/1000, start_ms/1000))
170
+ last_end = end_ms
171
+
172
+ if last_end < len(audio):
173
+ clean_segments.append(audio[last_end:])
174
+ kept_ranges.append((last_end/1000, len(audio)/1000))
175
+
176
+ if not clean_segments:
177
+ return AudioSegment.silent(duration=0), []
178
+
179
+ return sum(clean_segments), kept_ranges
180
+
181
+ def process(self, input_file, output_format="wav", threshold=0.50, window_size=5.0, hop_size=5.0, progress=None):
182
+ try:
183
+ if progress:
184
+ progress(0, desc="Loading audio...")
185
+
186
+ audio_array, sample_rate, audio = self.load_audio(input_file)
187
+ original_duration = len(audio) / 1000
188
+
189
+ if progress:
190
+ progress(0.2, desc="Detecting music...")
191
+
192
+ music_segments = self.detect_music_segments(
193
+ audio_array, sample_rate, threshold, window_size, hop_size
194
+ )
195
+
196
+ if progress:
197
+ progress(0.6, desc="Processing...")
198
+
199
+ music_segments = self.merge_overlapping_segments(music_segments)
200
+
201
+ if progress:
202
+ progress(0.8, desc="Removing music...")
203
+
204
+ clean_audio, kept_ranges = self.remove_music(audio, music_segments)
205
+ clean_duration = len(clean_audio) / 1000
206
+ removed_duration = original_duration - clean_duration
207
+
208
+ if progress:
209
+ progress(0.9, desc="Saving...")
210
+
211
+ format_settings = {
212
+ "wav": {"format": "wav"},
213
+ "mp3": {"format": "mp3", "bitrate": "192k"},
214
+ "ogg": {"format": "ogg", "codec": "libvorbis"}
215
+ }
216
+ settings = format_settings.get(output_format, format_settings["wav"])
217
+
218
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{output_format}") as tmp_file:
219
+ clean_audio.export(tmp_file.name, **settings)
220
+ output_path = tmp_file.name
221
+
222
+ if progress:
223
+ progress(1.0, desc="Complete!")
224
+
225
+ segments_detail = ""
226
+ if music_segments:
227
+ segments_detail = "\n### 🎡 Detected Music Segments:\n| # | Start | End | Confidence |\n|---|-------|-----|------------|\n"
228
+ for i, (start_ms, end_ms, score) in enumerate(music_segments, 1):
229
+ confidence = "🟒 High" if score > 0.7 else "🟑 Medium" if score > 0.5 else "🟠 Low"
230
+ segments_detail += f"| {i} | {start_ms/1000:.1f}s | {end_ms/1000:.1f}s | {score:.0%} {confidence} |\n"
231
+ else:
232
+ segments_detail = "\n### βœ… No music detected!\n"
233
+
234
+ report = f"""
235
+ ## πŸ“Š Processing Report
236
+
237
+ | Metric | Value |
238
+ |--------|-------|
239
+ | Original Duration | {original_duration:.2f}s |
240
+ | Clean Duration | {clean_duration:.2f}s |
241
+ | Removed Duration | {removed_duration:.2f}s ({(removed_duration/original_duration)*100:.1f}%) |
242
+ | Music Segments | {len(music_segments)} |
243
+ | Output Format | {output_format.upper()} |
244
+ {segments_detail}
245
+ """
246
+
247
+ logger.info(f"Complete: {original_duration:.1f}s -> {clean_duration:.1f}s")
248
+
249
+ return output_path, report
250
+
251
+ except Exception as e:
252
+ logger.error(f"Failed: {str(e)}")
253
+ return None, f"Error: {str(e)}"
254
+
255
+
256
+ logger.info("Starting CleanSpeech AI...")
257
+ remover = MusicRemover()
258
+
259
+
260
+ def process_audio(audio_file, output_format, progress=gr.Progress()):
261
+ if audio_file is None:
262
+ return None, "Please upload an audio file."
263
+
264
+ return remover.process(audio_file, output_format, progress=progress)
265
+
266
+
267
+ with gr.Blocks(title="CleanSpeech AI") as demo:
268
+
269
+ gr.Markdown("""
270
+ # 🎯 CleanSpeech AI
271
+ ### Remove Background Music from Audio
272
+
273
+ Upload your audio file and automatically detect and remove background music.
274
+ """)
275
+
276
+ with gr.Row():
277
+ with gr.Column(scale=1):
278
+ audio_input = gr.Audio(label="🎀 Upload Audio", type="filepath")
279
+ output_format = gr.Dropdown(
280
+ choices=["wav", "mp3", "ogg"],
281
+ value="wav",
282
+ label="πŸ“ Output Format"
283
+ )
284
+ process_btn = gr.Button("πŸš€ Remove Music", variant="primary", size="lg")
285
+
286
+ with gr.Column(scale=1):
287
+ audio_output = gr.Audio(label="πŸ”Š Cleaned Audio")
288
+ report_output = gr.Markdown()
289
+
290
+ process_btn.click(
291
+ fn=process_audio,
292
+ inputs=[audio_input, output_format],
293
+ outputs=[audio_output, report_output]
294
+ )
295
+ demo.queue()
296
+ demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ torchaudio
4
+ gradio
5
+ librosa
6
+ soundfile
7
+ numpy
8
+ pydub