Aynursusuz committed on
Commit
7249637
·
verified ·
1 Parent(s): 21973a8

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +225 -0
  2. packages.txt +1 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch

# Optional Hugging Face Spaces helper: when the `spaces` package is importable
# we are on ZeroGPU hardware and can request a GPU per call via @spaces.GPU.
try:
    import spaces
    ZERO_GPU = True
except ImportError:
    ZERO_GPU = False
import numpy as np
from transformers import ASTForAudioClassification, AutoFeatureExtractor
from pydub import AudioSegment
import tempfile
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Audio Spectrogram Transformer checkpoint used to classify audio windows.
MODEL_NAME = "Vyvo-Research/AST-Music-Classifier-1K"
DETECTION_THRESHOLD = 0.50  # minimum softmax score to accept a "music" prediction
WINDOW_SIZE = 5.0  # analysis window length, seconds
HOP_SIZE = 5.0     # stride between windows, seconds (== WINDOW_SIZE, so windows do not overlap)

# Model is loaded once at import time on CPU; process_audio() may later move it
# to CUDA in half precision.
logger.info("Loading model on CPU...")
model = ASTForAudioClassification.from_pretrained(MODEL_NAME)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model.eval()
logger.info("Model loaded")
28
+
29
+
30
def load_audio(file_path: str, target_sr: int):
    """Decode an audio file into a normalized float32 mono waveform.

    Returns a tuple ``(samples, segment)``: ``samples`` is a 1-D numpy array
    scaled to roughly [-1, 1]; ``segment`` is the mono, resampled pydub
    AudioSegment (kept so the caller can cut and export it losslessly).
    """
    segment = AudioSegment.from_file(file_path).set_channels(1).set_frame_rate(target_sr)
    raw = segment.get_array_of_samples()
    # Normalize integer PCM by the sample dtype's full-scale positive value.
    full_scale = np.iinfo(segment.array_type).max
    waveform = np.array(raw).astype(np.float32) / full_scale
    return waveform, segment
36
+
37
+
38
@torch.no_grad()
def detect_music(audio_array, sample_rate):
    """Slide a fixed-size window over the waveform and classify each window.

    Args:
        audio_array: 1-D float waveform (as returned by load_audio).
        sample_rate: sampling rate of audio_array in Hz.

    Returns:
        List of (start_ms, end_ms, score) tuples for windows judged to be music.
    """
    window_samples = int(WINDOW_SIZE * sample_rate)
    hop_samples = int(HOP_SIZE * sample_rate)
    total_samples = len(audio_array)

    music_segments = []
    # Hysteresis flag: lets an uncertain, zero-padded tail window inherit the
    # previous window's verdict instead of trusting its own skewed score.
    last_was_music = False
    device = next(model.parameters()).device
    # process_audio() casts the model to fp16 on CUDA, so inputs must match.
    use_half = device.type == "cuda"

    for start in range(0, total_samples, hop_samples):
        end = min(start + window_samples, total_samples)
        segment = audio_array[start:end]

        # Skip fragments shorter than one second: too little context to classify.
        if len(segment) < sample_rate:
            continue

        needs_padding = len(segment) < window_samples
        if needs_padding:
            # Zero-pad the final partial window up to the full window length.
            segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')

        inputs = feature_extractor(
            segment,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024
        )

        # Move features to the model's device, matching its dtype on GPU.
        if use_half:
            inputs = {k: v.to(device).half() for k, v in inputs.items()}
        else:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

        pred_idx = torch.argmax(probs[0]).item()
        pred_label = model.config.id2label.get(pred_idx, "")
        pred_score = probs[0][pred_idx].item()

        # Any label containing "music" counts as a music prediction.
        is_music = "music" in pred_label.lower()
        # Borderline-confidence band straddling DETECTION_THRESHOLD.
        is_uncertain = 0.40 <= pred_score <= 0.60

        start_sec = start / sample_rate
        end_sec = end / sample_rate

        if is_uncertain and needs_padding:
            # Padding-skewed, borderline prediction: carry the previous verdict.
            if last_was_music:
                music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
        elif is_music and pred_score >= DETECTION_THRESHOLD:
            music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
            last_was_music = True
        else:
            last_was_music = False

    return music_segments
97
+
98
+
99
def merge_segments(segments):
    """Collapse overlapping or touching (start_ms, end_ms, score) spans.

    Spans are sorted by start; any span that begins at or before the end of
    the previous merged span is folded into it, keeping the farther end and
    the higher score. Returns a new sorted, non-overlapping list.
    """
    if not segments:
        return []

    merged = []
    for span in sorted(segments, key=lambda s: s[0]):
        if merged and span[0] <= merged[-1][1]:
            prev_start, prev_end, prev_score = merged[-1]
            merged[-1] = (prev_start, max(prev_end, span[1]), max(prev_score, span[2]))
        else:
            merged.append(span)

    return merged
114
+
115
+
116
def remove_music_segments(audio, segments):
    """Cut the given spans out of an audio segment and splice the rest together.

    Args:
        audio: the full pydub AudioSegment (slicing is in milliseconds).
        segments: (start_ms, end_ms, score) tuples sorted by start, as produced
            by merge_segments(); scores are ignored here.

    Returns:
        A new AudioSegment with the spans removed; the original object when
        there is nothing to remove; a zero-length silent segment when the
        spans cover the whole input.
    """
    if not segments:
        return audio

    clean_parts = []
    last_end = 0

    for start_ms, end_ms, _ in segments:
        if start_ms > last_end:
            clean_parts.append(audio[last_end:start_ms])
        # BUGFIX: never move last_end backwards. A span contained inside an
        # earlier one (end_ms < last_end) would otherwise re-include audio
        # that was already removed.
        last_end = max(last_end, end_ms)

    if last_end < len(audio):
        clean_parts.append(audio[last_end:])

    if not clean_parts:
        return AudioSegment.silent(duration=0)

    return sum(clean_parts)
135
+
136
+
137
def build_report(original_dur, clean_dur, segments):
    """Render a markdown summary of the processing run.

    Args:
        original_dur: input duration in seconds.
        clean_dur: output duration in seconds.
        segments: merged (start_ms, end_ms, score) music spans.

    Returns:
        Markdown with a metrics table and, when music was found, a
        per-segment table; otherwise a "no music" note.
    """
    removed = original_dur - clean_dur
    # Guard the percentage against a zero-length input.
    pct = (removed / original_dur) * 100 if original_dur > 0 else 0

    parts = [f"""## Processing Report

| Metric | Value |
|--------|-------|
| Original Duration | {original_dur:.2f}s |
| Clean Duration | {clean_dur:.2f}s |
| Removed | {removed:.2f}s ({pct:.1f}%) |
| Segments Found | {len(segments)} |
| Output Format | WAV |
"""]

    if not segments:
        parts.append("\n*No music detected in this audio.*\n")
    else:
        parts.append("\n### Detected Music Segments\n| # | Start | End | Confidence |\n|---|-------|-----|------------|\n")
        parts.extend(
            f"| {idx} | {s_ms/1000:.1f}s | {e_ms/1000:.1f}s | {conf:.0%} |\n"
            for idx, (s_ms, e_ms, conf) in enumerate(segments, 1)
        )

    return "".join(parts)
160
+
161
+
162
@spaces.GPU if ZERO_GPU else lambda f: f
def process_audio(audio_file, progress=gr.Progress()):
    """End-to-end pipeline: load audio, detect music windows, cut them out,
    export a cleaned WAV, and build a markdown report.

    Returns (output_wav_path, report_markdown), or (None, error_text) when
    the input is missing or any stage fails.
    """
    if audio_file is None:
        return None, "Please upload an audio file."

    try:
        progress(0.1, desc="Preparing model...")
        # Promote the globally loaded model to GPU in half precision if we can.
        if torch.cuda.is_available():
            model.to("cuda").half()
            torch.backends.cudnn.benchmark = True

        progress(0.2, desc="Loading audio...")
        target_sr = feature_extractor.sampling_rate
        waveform, source_audio = load_audio(audio_file, target_sr)
        total_seconds = len(source_audio) / 1000

        progress(0.4, desc="Detecting music...")
        music_spans = merge_segments(detect_music(waveform, target_sr))

        progress(0.7, desc="Processing...")
        cleaned = remove_music_segments(source_audio, music_spans)
        cleaned_seconds = len(cleaned) / 1000

        progress(0.9, desc="Exporting...")
        # delete=False: gradio serves this file after the function returns.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            cleaned.export(tmp.name, format="wav")
            result_path = tmp.name

        progress(1.0, desc="Done")
        summary = build_report(total_seconds, cleaned_seconds, music_spans)

        return result_path, summary

    except Exception as e:
        logger.exception("Processing failed")
        return None, f"Error: {str(e)}"
199
+
200
+
201
# --- Gradio user interface ---------------------------------------------------
with gr.Blocks(title="CleanSpeech AI") as demo:
    # Header / instructions.
    gr.Markdown("""
    # CleanSpeech AI
    ### Remove Background Music from Audio

    Upload your audio file to automatically detect and remove background music.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input column: type="filepath" hands process_audio a path string.
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            process_btn = gr.Button("Remove Music", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output column: cleaned audio player plus the markdown report.
            audio_output = gr.Audio(label="Cleaned Audio")
            report_output = gr.Markdown()

    # Wire the button to the pipeline: one input, two outputs.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[audio_output, report_output]
    )

# queue() enables gr.Progress updates; launch() starts the server.
demo.queue()
demo.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ torchaudio
4
+ gradio
5
+ librosa
6
+ soundfile
7
+ numpy
8
+ pydub