Aynursusuz committed on
Commit
85cf8e6
·
verified ·
1 Parent(s): 7249637

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -225
app.py DELETED
@@ -1,225 +0,0 @@
1
import gradio as gr
import torch

# Hugging Face ZeroGPU spaces provide a `spaces` module; when it is missing
# (e.g. running locally) fall back to plain execution with no GPU decorator.
try:
    import spaces
    ZERO_GPU = True
except ImportError:
    ZERO_GPU = False
import numpy as np
from transformers import ASTForAudioClassification, AutoFeatureExtractor
from pydub import AudioSegment
import tempfile
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Audio Spectrogram Transformer checkpoint used for music detection.
MODEL_NAME = "Vyvo-Research/AST-Music-Classifier-1K"
DETECTION_THRESHOLD = 0.50  # minimum softmax score to accept a "music" prediction
WINDOW_SIZE = 5.0  # seconds of audio classified per model call
HOP_SIZE = 5.0     # seconds between window starts (== WINDOW_SIZE, so windows do not overlap)

# Model is loaded once at import time on CPU; process_audio() later moves it
# to CUDA (half precision) when a GPU is available.
logger.info("Loading model on CPU...")
model = ASTForAudioClassification.from_pretrained(MODEL_NAME)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model.eval()
logger.info("Model loaded")
28
-
29
-
30
def load_audio(file_path: str, target_sr: int):
    """Decode an audio file to mono float samples at the requested rate.

    Returns a tuple ``(samples, segment)`` where ``samples`` is a float32
    numpy array normalized to roughly [-1, 1] and ``segment`` is the mono,
    resampled :class:`pydub.AudioSegment` (kept for millisecond slicing and
    export later in the pipeline).
    """
    segment = AudioSegment.from_file(file_path)
    segment = segment.set_channels(1).set_frame_rate(target_sr)

    raw = np.array(segment.get_array_of_samples()).astype(np.float32)
    # Scale integer PCM by the sample type's full-scale positive value.
    full_scale = np.iinfo(segment.array_type).max
    raw = raw / full_scale
    return raw, segment
36
-
37
-
38
@torch.no_grad()
def detect_music(audio_array, sample_rate):
    """Slide a WINDOW_SIZE-second window over ``audio_array`` and classify each
    window with the global AST model.

    Returns a list of ``(start_ms, end_ms, score)`` triples for windows judged
    to contain music. Windows shorter than one second are skipped; a shorter
    final window is zero-padded to the full window length before scoring.
    """
    window_samples = int(WINDOW_SIZE * sample_rate)
    hop_samples = int(HOP_SIZE * sample_rate)
    total_samples = len(audio_array)

    music_segments = []
    last_was_music = False  # carries context into the uncertain-tail heuristic below
    device = next(model.parameters()).device
    use_half = device.type == "cuda"  # model weights are .half() only on CUDA (see process_audio)

    for start in range(0, total_samples, hop_samples):
        end = min(start + window_samples, total_samples)
        segment = audio_array[start:end]

        # Too little signal to classify reliably; skip sub-second tails.
        if len(segment) < sample_rate:
            continue

        needs_padding = len(segment) < window_samples
        if needs_padding:
            segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')

        inputs = feature_extractor(
            segment,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024
        )

        # Inputs must match the model's dtype: half on CUDA, float32 otherwise.
        if use_half:
            inputs = {k: v.to(device).half() for k, v in inputs.items()}
        else:
            inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

        pred_idx = torch.argmax(probs[0]).item()
        pred_label = model.config.id2label.get(pred_idx, "")
        pred_score = probs[0][pred_idx].item()

        # Label-name match rather than a fixed index; assumes the checkpoint's
        # id2label contains a label with "music" in it — TODO confirm.
        is_music = "music" in pred_label.lower()
        is_uncertain = 0.40 <= pred_score <= 0.60

        start_sec = start / sample_rate
        end_sec = end / sample_rate

        # Heuristic: an uncertain score on a zero-padded (final, short) window
        # inherits the previous window's verdict — padding dilutes confidence.
        if is_uncertain and needs_padding:
            if last_was_music:
                music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
        elif is_music and pred_score >= DETECTION_THRESHOLD:
            music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
            last_was_music = True
        else:
            last_was_music = False

    return music_segments
97
-
98
-
99
def merge_segments(segments):
    """Coalesce overlapping or touching ``(start_ms, end_ms, score)`` triples.

    Segments are sorted by start time; any segment whose start falls at or
    before the end of the previously kept segment is folded into it, keeping
    the larger end time and the larger confidence score.
    """
    if not segments:
        return []

    result = []
    for seg_start, seg_end, seg_score in sorted(segments, key=lambda s: s[0]):
        if result and seg_start <= result[-1][1]:
            prev_start, prev_end, prev_score = result[-1]
            result[-1] = (prev_start, max(prev_end, seg_end), max(prev_score, seg_score))
        else:
            result.append((seg_start, seg_end, seg_score))

    return result
114
-
115
-
116
def remove_music_segments(audio, segments):
    """Cut the given ``(start_ms, end_ms, score)`` spans out of ``audio``.

    ``audio`` is a pydub AudioSegment (sliced in milliseconds). The parts
    between/around the music spans are concatenated and returned; if nothing
    survives, a zero-length silent segment is returned.
    """
    if not segments:
        return audio

    kept = []
    cursor = 0  # end (ms) of the last region already handled
    for seg_start, seg_end, _score in segments:
        if seg_start > cursor:
            kept.append(audio[cursor:seg_start])
        cursor = seg_end

    # Tail after the final music span.
    if cursor < len(audio):
        kept.append(audio[cursor:])

    if not kept:
        return AudioSegment.silent(duration=0)

    result = kept[0]
    for piece in kept[1:]:
        result = result + piece
    return result
135
-
136
-
137
def build_report(original_dur, clean_dur, segments):
    """Render a markdown report for one cleaning run.

    ``original_dur``/``clean_dur`` are seconds; ``segments`` is the merged
    list of ``(start_ms, end_ms, score)`` triples that were removed.
    """
    removed = original_dur - clean_dur
    pct = (removed / original_dur) * 100 if original_dur > 0 else 0

    parts = [
        "## Processing Report\n"
        "\n"
        "| Metric | Value |\n"
        "|--------|-------|\n"
        f"| Original Duration | {original_dur:.2f}s |\n"
        f"| Clean Duration | {clean_dur:.2f}s |\n"
        f"| Removed | {removed:.2f}s ({pct:.1f}%) |\n"
        f"| Segments Found | {len(segments)} |\n"
        "| Output Format | WAV |\n"
    ]

    if not segments:
        parts.append("\n*No music detected in this audio.*\n")
    else:
        parts.append(
            "\n### Detected Music Segments\n"
            "| # | Start | End | Confidence |\n"
            "|---|-------|-----|------------|\n"
        )
        for idx, (start_ms, end_ms, score) in enumerate(segments, 1):
            parts.append(
                f"| {idx} | {start_ms/1000:.1f}s | {end_ms/1000:.1f}s | {score:.0%} |\n"
            )

    return "".join(parts)
160
-
161
-
162
# PEP 614 (Python 3.9+) conditional decorator: wrap with spaces.GPU on
# ZeroGPU spaces, otherwise apply an identity decorator.
@spaces.GPU if ZERO_GPU else lambda f: f
def process_audio(audio_file, progress=gr.Progress()):
    """Gradio handler: detect and strip background music from an upload.

    ``audio_file`` is a filepath (or None when nothing was uploaded).
    Returns ``(output_wav_path, markdown_report)``; on failure returns
    ``(None, error_message)`` so the UI shows the error instead of crashing.
    """
    if audio_file is None:
        return None, "Please upload an audio file."

    try:
        progress(0.1, desc="Preparing model...")
        # Lazily move the module-level model to GPU in half precision.
        if torch.cuda.is_available():
            model.to("cuda").half()
            torch.backends.cudnn.benchmark = True

        progress(0.2, desc="Loading audio...")
        sample_rate = feature_extractor.sampling_rate
        audio_array, audio = load_audio(audio_file, sample_rate)
        original_duration = len(audio) / 1000  # pydub length is milliseconds

        progress(0.4, desc="Detecting music...")
        segments = detect_music(audio_array, sample_rate)
        segments = merge_segments(segments)

        progress(0.7, desc="Processing...")
        clean_audio = remove_music_segments(audio, segments)
        clean_duration = len(clean_audio) / 1000

        progress(0.9, desc="Exporting...")
        # delete=False: the file must outlive this handler so gradio can serve it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            clean_audio.export(f.name, format="wav")
            output_path = f.name

        progress(1.0, desc="Done")
        report = build_report(original_duration, clean_duration, segments)

        return output_path, report

    except Exception as e:
        logger.exception("Processing failed")
        return None, f"Error: {str(e)}"
199
-
200
-
201
# --- Gradio UI: upload on the left, cleaned audio + report on the right. ---
with gr.Blocks(title="CleanSpeech AI") as demo:
    gr.Markdown("""
# CleanSpeech AI
### Remove Background Music from Audio

Upload your audio file to automatically detect and remove background music.
""")

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            process_btn = gr.Button("Remove Music", variant="primary", size="lg")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Cleaned Audio")
            report_output = gr.Markdown()

    # Wire the button to the processing pipeline defined above.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[audio_output, report_output]
    )

# queue() is required for progress updates (and for ZeroGPU scheduling).
demo.queue()
demo.launch()