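"""SAM-Audio Segmenter: a ZeroGPU Gradio Space for text-prompted sound separation.

Upload an audio (or video) file, describe the sound to isolate, and receive the
isolated target plus the residual background. Files longer than 30 seconds are
processed in overlapping chunks that are crossfaded back together.
"""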
import gradio as gr
import torch
import torchaudio
import numpy as np
import os
import tempfile
import spaces
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# --- Custom Theme Configuration ---
class MidnightTheme(Soft):
    def __init__(self):
        super().__init__(
            # Custom text and button colors mapped onto the theme palettes
            primary_hue=colors.Color(
                name="brand",
                c50="#eef2ff", c100="#e0e7ff", c200="#c7d2fe", c300="#a5b4fc",
                c400="#818cf8", c500="#5248e9", c600="#4f46e5", c700="#4338ca",
                c800="#3730a3", c900="#312e81", c950="#1e1b4b"
            ),
            neutral_hue=colors.Color(
                name="dark_slate",
                c50="#f8fafc", c100="#f1f5f9", c200="#e2e8f0", c300="#cbd5e1",
                c400="#94a3b8", c500="#64748b", c600="#51748c", c700="#334155",  # c600 is the secondary text color
                c800="#20293c", c900="#10172b", c950="#030617"  # c800-c950 are the background/button darks
            ),
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
        )
        super().set(
            # Backgrounds
            body_background_fill="#030617",
            block_background_fill="#10172b",
            block_border_color="#20293c",
            # Text Colors
            body_text_color="#cdd6e2",
            block_label_text_color="#51748c",
            block_title_text_color="#cdd6e2",
            # Buttons
            button_primary_background_fill="#5248e9",
            button_primary_text_color="white",
            button_secondary_background_fill="#20293c",
            button_secondary_text_color="#cdd6e2",
            # Inputs
            input_background_fill="#030617",
            input_border_color="#20293c",
        )

midnight_theme = MidnightTheme()

# --- CSS for Layout Polish ---
css = """
#container { max-width: 1000px; margin: auto; padding-top: 2rem; }
#title-area { text-align: center; margin-bottom: 2rem; }
.gradio-container { background-color: #030617 !important; }
.output-audio { background-color: #030617 !important; }
"""

try:
    from sam_audio import SAMAudio, SAMAudioProcessor
except ImportError as e:
    # Fall back to None so the model-load guard below fails gracefully
    # instead of raising a NameError.
    SAMAudio = SAMAudioProcessor = None
    print(f"Warning: 'sam_audio' library not found. Please install it to use this app. Error: {e}")

MODEL_ID = "facebook/sam-audio-large"
DEFAULT_CHUNK_DURATION = 30.0
OVERLAP_DURATION = 2.0
MAX_DURATION_WITHOUT_CHUNKING = 30.0
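
# With these defaults, consecutive chunks overlap by 2 s, so the effective
# stride is 28 s: e.g. a 70 s file splits into [0-30 s], [28-58 s], [56-70 s].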

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading {MODEL_ID} on {device}...")

model = None
processor = None
try:
    model = SAMAudio.from_pretrained(MODEL_ID, token=os.environ.get("HF_TOKEN")).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
    print("✅ SAM-Audio loaded successfully.")
except Exception as e:
    print(f"❌ Error loading SAM-Audio: {e}")

def load_audio(file_path):
    """Load audio from a file (with an ffmpeg-enabled torchaudio backend this also works for video containers)."""
    waveform, sample_rate = torchaudio.load(file_path)
    if waveform.shape[0] > 1:
        # Downmix multichannel audio to mono by averaging channels.
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate

def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split audio waveform into overlapping chunks."""
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    chunks = []
    total_samples = waveform.shape[1]
    if total_samples <= chunk_samples:
        return [waveform]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunk = waveform[:, start:end]
        chunks.append(chunk)
        if end >= total_samples:
            break
        start += stride
    return chunks
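
# Illustrative sanity check (not executed by the app): a 70 s mono signal at
# 16 kHz with 30 s chunks and 2 s overlap yields lengths of 30 s, 30 s, 14 s.
#
#   wav = torch.zeros(1, 70 * 16000)
#   [c.shape[1] / 16000 for c in split_audio_into_chunks(wav, 16000, 30.0, 2.0)]
#   # -> [30.0, 30.0, 14.0]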

def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks with a crossfade on the overlapping regions."""
    if len(chunks) == 1:
        chunk = chunks[0]
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        return chunk
    overlap_samples = int(overlap_duration * sample_rate)
    processed_chunks = []
    for chunk in chunks:
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        processed_chunks.append(chunk)
    result = processed_chunks[0]
    for i in range(1, len(processed_chunks)):
        prev_chunk = result
        next_chunk = processed_chunks[i]
        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
        if actual_overlap <= 0:
            result = torch.cat([prev_chunk, next_chunk], dim=1)
            continue
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
        prev_overlap = prev_chunk[:, -actual_overlap:]
        next_overlap = next_chunk[:, :actual_overlap]
        crossfaded = prev_overlap * fade_out + next_overlap * fade_in
        result = torch.cat([
            prev_chunk[:, :-actual_overlap],
            crossfaded,
            next_chunk[:, actual_overlap:]
        ], dim=1)
    return result
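
# The linear ramps above sum to exactly 1.0 at every sample
# (fade_out + fade_in == 1), so where adjacent chunks agree in the overlap the
# merged signal passes through at unity gain. An equal-power (sqrt) crossfade
# would be an alternative if the boundaries sounded dipped.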

def save_audio(tensor, sample_rate):
    """Saves a tensor to a temporary WAV file and returns its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tensor = tensor.cpu()
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name
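
# Note: delete=False above is deliberate. Gradio reads filepath outputs after
# the handler returns, so the temp WAVs must outlive this function; nothing in
# this app unlinks them afterwards, so they persist until /tmp is cleaned up.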

@spaces.GPU  # Requests a GPU per call on ZeroGPU Spaces (this Space runs on Zero).
def process_audio(file_path, text_prompt, chunk_duration_val, progress=gr.Progress()):
    global model, processor
    if model is None or processor is None:
        return None, None, "❌ Model not loaded correctly. Check logs."
    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."
    try:
        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        c_dur = chunk_duration_val if chunk_duration_val else DEFAULT_CHUNK_DURATION
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, c_dur, OVERLAP_DURATION)
            num_chunks = len(chunks)
            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
                # Write the chunk to a temp file since the processor takes file paths.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name
                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    target_chunks.append(result.target[0].detach().cpu())
                    residual_chunks.append(result.residual[0].detach().cpu())
                finally:
                    if os.path.exists(chunk_path):
                        os.unlink(chunk_path)
            progress(0.85, desc="Merging chunks...")
            # The model outputs at the processor's sampling rate (which may differ
            # from the input's, as the non-chunked path below assumes), so merge
            # and save at that rate.
            out_sr = processor.audio_sampling_rate
            target_merged = merge_chunks_with_crossfade(target_chunks, out_sr, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, out_sr, OVERLAP_DURATION)
            progress(0.95, desc="Saving results...")
            target_path = save_audio(target_merged, out_sr)
            residual_path = save_audio(residual_merged, out_sr)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' ({num_chunks} chunks)"
        else:
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
            progress(0.9, desc="Saving results...")
            sr = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sr)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sr)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}'"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"
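
# Illustrative direct call (what the click handler below effectively performs;
# the path and prompt here are hypothetical):
#
#   target_path, residual_path, status = process_audio("input.wav", "dog barking", 30)
#   print(status)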

def dummy_process(file, text, duration):
    """Unused placeholder kept from development; the click handler binds process_audio."""
    return None, None, "Processing..."

with gr.Blocks(theme=midnight_theme, css=css) as demo:
    with gr.Column(elem_id="container"):
        # Header Section
        gr.Markdown(
            """
            # 🎙️ SAM-Audio Segmenter
            ### Isolate specific sounds using natural language descriptions.
            """,
            elem_id="title-area"
        )
        with gr.Row(equal_height=True):
            # Left Side: Inputs
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### 1. Upload & Describe")
                    input_file = gr.Audio(label="Input Audio Source", type="filepath")
                    text_prompt = gr.Textbox(
                        label="Target Sound",
                        placeholder="e.g. 'electric guitar solo' or 'birds chirping'",
                        info="What sound should we isolate from the background?"
                    )
                    with gr.Accordion("Advanced Processing Settings", open=False):
                        chunk_duration_slider = gr.Slider(
                            minimum=10, maximum=60, value=30, step=5,
                            label="Chunk Duration (s)",
                            info="Shorter chunks save memory for long files."
                        )
                    run_btn = gr.Button("🚀 Start Separation", variant="primary")
            # Right Side: Outputs
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### 2. Results")
                    output_target = gr.Audio(label="Isolated Result", type="filepath")
                    output_residual = gr.Audio(label="Background / Remainder", type="filepath")
                    status_out = gr.Textbox(label="Status Log", interactive=False, lines=2)
        # Examples Section at Bottom
        gr.Markdown("---")
        gr.Examples(
            examples=[
                ["example_audio/speech.mp3", "Music", 30],
                ["example_audio/song.mp3", "Drum", 30]
            ],
            inputs=[input_file, text_prompt, chunk_duration_slider],
            label="Try an Example"
        )
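        # These example rows assume the files exist under example_audio/ in the
        # Space repository; adjust or remove them if the assets are absent.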

    # Event Binding
    run_btn.click(
        fn=process_audio,
        inputs=[input_file, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_out]
    )

if __name__ == "__main__":
    demo.launch()