| """ | |
| Gradio Music Tagging GUI with Fine-Tuned Llama Review Generation | |
| A web interface for tagging music and generating Pitchfork-style reviews using: | |
| 1. Zero-Shot: CLAP + MuLan models with pre-computed tag embeddings | |
| 2. MTG-Jamendo (Multi-Head): Trained MAEST classifier with separate heads for | |
| genre, instrument, and mood/theme | |
| 3. Review Generation: Fine-tuned Llama 3.2 3B model via transformers | |
| Also supports: | |
| - Vocal separation using Demucs | |
| - Lyric transcription using Whisper | |
| Usage: | |
| python app_with_llama_reviews.py | |
| Requirements: | |
| pip install -r requirements.txt | |
| pip install transformers accelerate # For Llama inference | |
| Note: This version loads the fine-tuned model directly from HuggingFace Hub using transformers. | |
| """ | |
import os
import warnings
from pathlib import Path

import torch
import numpy as np
import librosa
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr

# Suppress torchaudio deprecation warnings from Demucs internals
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")
# ============================================================
# Configuration
# ============================================================

class Config:
    """Configuration for all inference methods."""

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Zero-Shot paths (relative to script directory)
    clap_embeddings_path = "clap_tag_embeddings.npy"
    mulan_embeddings_path = "mulan_tag_embeddings.npy"
    tag_names_path = "musiccaps_tag_names.txt"

    # MTG-Jamendo Multi-Head checkpoint (MAEST)
    multihead_checkpoint_path = "best_maest_model.pt"

    # MAEST settings for Multi-Head classification
    maest_model_name = "discogs-maest-30s-pw-129e"
    maest_layer = 6  # Intermediate layer for best performance (per the MAEST paper)
    maest_sample_rate = 16000
    maest_feature_dim = 2304  # CLS (768) + DIST (768) + avg tokens (768)
    max_duration = 30  # seconds

    # Llama review model settings (transformers)
    llama_hf_repo = "tventurella/llama-pitchfork-merged"  # HuggingFace model repo
    llama_max_new_tokens = 800

    # Whisper settings
    whisper_model_name = "turbo"  # Options: tiny, base, small, medium, large-v3, turbo
    whisper_device = None  # None = auto-detect, or specify "cpu"/"cuda"

    # Demucs settings
    demucs_model_name = "htdemucs"  # Hybrid Transformer Demucs v4
    vocal_energy_threshold = 0.02  # Minimum RMS energy to consider vocals present
    max_lyrics_duration = 60  # Maximum duration (seconds) of audio used for lyric analysis


config = Config()

# Get the script directory for relative paths
SCRIPT_DIR = Path(__file__).parent.resolve()
# ============================================================
# Model Classes
# ============================================================

class MultiHeadMusicTagger(nn.Module):
    """Multi-head classifier for genre, instrument, and mood/theme using MAEST features."""

    def __init__(self, input_dim=2304, num_genres=87, num_instruments=40,
                 num_moods=56, hidden_dim=512, dropout=0.5):
        super().__init__()

        # Shared backbone (matches training architecture)
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),  # 2304 -> 1024
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),  # 1024 -> 512
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        head_input_dim = hidden_dim  # 512

        self.genre_head = nn.Sequential(
            nn.Linear(head_input_dim, head_input_dim // 2),  # 512 -> 256
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_input_dim // 2, num_genres)  # 256 -> num_genres
        )
        self.instrument_head = nn.Sequential(
            nn.Linear(head_input_dim, head_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_input_dim // 2, num_instruments)
        )
        self.mood_head = nn.Sequential(
            nn.Linear(head_input_dim, head_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_input_dim // 2, num_moods)
        )

    def forward(self, x):
        features = self.backbone(x)
        return {
            'genre': self.genre_head(features),
            'instrument': self.instrument_head(features),
            'mood': self.mood_head(features)
        }
# ============================================================
# Global Model Cache (lazy loading)
# ============================================================

class ModelCache:
    """Lazy-loaded model cache to avoid loading models until needed."""

    def __init__(self):
        self._zero_shot_models = None
        self._zero_shot_embeddings = None
        self._zero_shot_tag_names = None
        # Multi-head model cache (MAEST)
        self._multihead_model = None
        self._multihead_tags = None
        self._maest_model = None
        # Lyrics models cache
        self._demucs_model = None
        self._whisper_model = None
        # Llama review model
        self._llama_model = None

    def get_zero_shot_models(self):
        """Load and cache Zero-Shot models (CLAP + MuLan)."""
        if self._zero_shot_models is None:
            print("Loading Zero-Shot models (this may take a moment)...")
            try:
                from muq import MuQMuLan
                from transformers import ClapModel, ClapProcessor
            except ImportError as e:
                raise ImportError(
                    f"Missing dependencies for Zero-Shot tagging: {e}\n"
                    "Install with: pip install muq laion-clap transformers"
                )
            print(" Loading MuQ-MuLan...")
            mulan_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
            mulan_model = mulan_model.to(config.device).eval()
            print(" Loading CLAP...")
            clap_model = ClapModel.from_pretrained("laion/larger_clap_music_and_speech")
            clap_model = clap_model.to(config.device).eval()
            clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music_and_speech")
            self._zero_shot_models = (mulan_model, clap_model, clap_processor)
            print(" Zero-Shot models loaded!")
        return self._zero_shot_models

    def get_zero_shot_embeddings(self):
        """Load and cache pre-computed tag embeddings."""
        if self._zero_shot_embeddings is None:
            clap_path = SCRIPT_DIR / config.clap_embeddings_path
            mulan_path = SCRIPT_DIR / config.mulan_embeddings_path
            if not clap_path.exists() or not mulan_path.exists():
                raise FileNotFoundError(
                    f"Pre-computed embeddings not found!\n"
                    f"Expected:\n"
                    f" - {clap_path}\n"
                    f" - {mulan_path}\n\n"
                    f"Run create_embeddings.py first to generate these files."
                )
            print("Loading pre-computed embeddings...")
            clap_embeddings = np.load(str(clap_path))
            mulan_embeddings = np.load(str(mulan_path))
            self._zero_shot_embeddings = (clap_embeddings, mulan_embeddings)
            print(f" Loaded embeddings: CLAP {clap_embeddings.shape}, MuLan {mulan_embeddings.shape}")
        return self._zero_shot_embeddings

    def get_zero_shot_tag_names(self):
        """Load and cache tag names for Zero-Shot tagging."""
        if self._zero_shot_tag_names is None:
            tag_path = SCRIPT_DIR / config.tag_names_path
            if not tag_path.exists():
                raise FileNotFoundError(
                    f"Tag names file not found: {tag_path}\n"
                    "Run create_embeddings.py first to generate this file."
                )
            with open(tag_path, 'r', encoding='utf-8') as f:
                self._zero_shot_tag_names = [line.strip() for line in f if line.strip()]
            print(f" Loaded {len(self._zero_shot_tag_names)} tag names")
        return self._zero_shot_tag_names

    def get_multihead_model(self):
        """Load and cache the MTG-Jamendo multi-head trained model (MAEST-based)."""
        if self._multihead_model is None:
            checkpoint_path = SCRIPT_DIR / config.multihead_checkpoint_path
            if not checkpoint_path.exists():
                raise FileNotFoundError(
                    f"Multi-head MAEST checkpoint not found: {checkpoint_path}\n"
                    "Train the model first using the MAEST training notebook."
                )
            print("Loading MAEST multi-head model...")
            checkpoint = torch.load(str(checkpoint_path), map_location=config.device, weights_only=False)
            genre_tags = checkpoint['genre_tags']
            instrument_tags = checkpoint['instrument_tags']
            mood_tags = checkpoint['mood_tags']

            # Get config from the checkpoint if available
            model_config = checkpoint.get('config', {})
            maest_feature_dim = model_config.get('maest_feature_dim', config.maest_feature_dim)
            hidden_dim = model_config.get('hidden_dim', 512)  # Training used 512

            # Update global config with checkpoint settings if available
            if 'maest_layer' in model_config:
                config.maest_layer = model_config['maest_layer']

            model = MultiHeadMusicTagger(
                input_dim=maest_feature_dim,
                num_genres=len(genre_tags),
                num_instruments=len(instrument_tags),
                num_moods=len(mood_tags),
                hidden_dim=hidden_dim
            )
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(config.device).eval()

            self._multihead_model = model
            self._multihead_tags = {
                'genre': genre_tags,
                'instrument': instrument_tags,
                'mood': mood_tags
            }
            print(" Loaded MAEST multi-head model:")
            print(f" Genres: {len(genre_tags)}")
            print(f" Instruments: {len(instrument_tags)}")
            print(f" Moods: {len(mood_tags)}")
        return self._multihead_model, self._multihead_tags

    def get_maest_model(self):
        """Load and cache the MAEST model for multi-head feature extraction."""
        if self._maest_model is None:
            try:
                from maest import get_maest
            except ImportError as e:
                raise ImportError(
                    f"Missing maest library: {e}\n"
                    "Install with: pip install git+https://github.com/palonso/MAEST.git"
                )
            print(f"Loading MAEST model ({config.maest_model_name})...")
            self._maest_model = get_maest(arch=config.maest_model_name)
            # Keep MAEST on CPU for stable STFT processing (avoids device mismatch)
            self._maest_model = self._maest_model.eval()
            print(" MAEST model loaded (CPU mode for stable STFT)!")
        return self._maest_model

    def get_demucs_model(self):
        """Load and cache the Demucs model for vocal separation."""
        if self._demucs_model is None:
            try:
                from demucs import pretrained
            except ImportError as e:
                raise ImportError(
                    f"Missing demucs library: {e}\n"
                    "Install with: pip install demucs"
                )
            print("Loading Demucs model (this may take a moment)...")
            self._demucs_model = pretrained.get_model(config.demucs_model_name)
            self._demucs_model = self._demucs_model.to(config.device).eval()
            print(f" Demucs {config.demucs_model_name} loaded!")
        return self._demucs_model

    def get_whisper_model(self):
        """Load and cache the Whisper model for lyric transcription."""
        if self._whisper_model is None:
            try:
                import whisper
            except ImportError as e:
                raise ImportError(
                    f"Missing whisper library: {e}\n"
                    "Install with: pip install openai-whisper"
                )
            print(f"Loading Whisper model '{config.whisper_model_name}' (this may take a moment)...")
            device = config.whisper_device if config.whisper_device else config.device
            self._whisper_model = whisper.load_model(config.whisper_model_name, device=device)
            print(f" Whisper {config.whisper_model_name} loaded!")
        return self._whisper_model

    def get_llama_model(self):
        """Load the Llama model using transformers for review generation."""
        if self._llama_model is None:
            try:
                from transformers import AutoModelForCausalLM, AutoTokenizer
            except ImportError as e:
                raise ImportError(
                    f"Missing transformers library: {e}\n"
                    "Install with: pip install transformers accelerate"
                )
            print(f"Loading Llama model from {config.llama_hf_repo}...")
            print(" (This may take a moment...)")
            tokenizer = AutoTokenizer.from_pretrained(config.llama_hf_repo)
            model = AutoModelForCausalLM.from_pretrained(
                config.llama_hf_repo,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                device_map="auto",
            )
            model.eval()
            self._llama_model = (model, tokenizer)
            print(" Llama model loaded!")
        return self._llama_model


# Global model cache
model_cache = ModelCache()
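# Note: every model in the cache above is loaded lazily on first use and then kept in
# memory for the lifetime of the process, so the first request that touches a given
# model will be noticeably slower than subsequent ones.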
# ============================================================
# Inference Functions
# ============================================================

def tag_audio_zero_shot(audio_path: str, top_k: int = 20, normalization: str = "individual"):
    """
    Tag audio using the Zero-Shot CLAP + MuLan approach.

    Args:
        audio_path: Path to audio file
        top_k: Number of top tags to return
        normalization: "mulan_only", "global", or "individual"

    Returns:
        List of (tag_name, confidence) tuples
    """
    mulan_model, clap_model, clap_processor = model_cache.get_zero_shot_models()
    clap_embeddings, mulan_embeddings = model_cache.get_zero_shot_embeddings()
    tag_names = model_cache.get_zero_shot_tag_names()

    # Keep the tag names and both embedding matrices aligned to the same length
    min_len = min(len(tag_names), len(clap_embeddings), len(mulan_embeddings))
    tag_names = tag_names[:min_len]
    clap_embeddings = clap_embeddings[:min_len]
    mulan_embeddings = mulan_embeddings[:min_len]

    # MuLan audio embedding (24 kHz mono)
    wav_mulan, _ = librosa.load(audio_path, sr=24000, mono=True)
    wavs = torch.tensor(wav_mulan, dtype=torch.float32).unsqueeze(0).to(config.device)
    mulan_output = mulan_model(wavs=wavs)
    if isinstance(mulan_output, torch.Tensor):
        mulan_audio_embed = mulan_output
    else:
        mulan_audio_embed = mulan_output.pooler_output

    # CLAP audio embedding (48 kHz mono)
    wav_clap, _ = librosa.load(audio_path, sr=48000, mono=True)
    inputs = clap_processor(audio=wav_clap, sampling_rate=48000, return_tensors="pt").to(config.device)
    clap_output = clap_model.get_audio_features(**inputs)
    if isinstance(clap_output, torch.Tensor):
        clap_audio_embed = clap_output
    else:
        clap_audio_embed = clap_output.pooler_output

    # Cosine similarity between the audio embeddings and every pre-computed tag embedding
    mulan_text_e = torch.tensor(mulan_embeddings, dtype=torch.float32).to(config.device)
    clap_text_e = torch.tensor(clap_embeddings, dtype=torch.float32).to(config.device)
    mulan_sims = F.cosine_similarity(mulan_audio_embed, mulan_text_e, dim=1)
    clap_sims = F.cosine_similarity(clap_audio_embed, clap_text_e, dim=1)

    if normalization == "mulan_only":
        combined = mulan_sims
    elif normalization == "global":
        all_sims = torch.cat([mulan_sims, clap_sims])
        g_min, g_max = all_sims.min(), all_sims.max()
        mulan_norm = (mulan_sims - g_min) / (g_max - g_min + 1e-8)
        clap_norm = (clap_sims - g_min) / (g_max - g_min + 1e-8)
        combined = 0.5 * mulan_norm + 0.5 * clap_norm
    else:  # individual
        mulan_norm = (mulan_sims - mulan_sims.min()) / (mulan_sims.max() - mulan_sims.min() + 1e-8)
        clap_norm = (clap_sims - clap_sims.min()) / (clap_sims.max() - clap_sims.min() + 1e-8)
        combined = 0.5 * mulan_norm + 0.5 * clap_norm

    top_scores, top_idx = torch.topk(combined, k=min(top_k, len(tag_names)))
    predictions = []
    for i, idx in enumerate(top_idx):
        tag = tag_names[idx.item()]
        score = top_scores[i].item()
        predictions.append((tag, score))
    return predictions
def tag_audio_multihead(audio_path: str, top_k: int = 5):
    """
    Tag audio using the trained MTG-Jamendo multi-head MAEST classifier.
    Returns predictions for genre, instrument, and mood/theme.

    Args:
        audio_path: Path to audio file
        top_k: Number of top tags to return per category

    Returns:
        Dict with 'genre', 'instrument', 'mood' keys, each containing a
        list of (tag_name, confidence) tuples
    """
    model, tags = model_cache.get_multihead_model()
    maest_model = model_cache.get_maest_model()

    # Load audio at 16 kHz for MAEST
    wav, sr = librosa.load(audio_path, sr=config.maest_sample_rate, mono=True)

    # Ensure exactly 30 seconds (pad or truncate)
    max_samples = config.maest_sample_rate * config.max_duration
    if len(wav) > max_samples:
        wav = wav[:max_samples]
    elif len(wav) < max_samples:
        # Pad with zeros if shorter
        wav = np.pad(wav, (0, max_samples - len(wav)), mode='constant')

    # Keep tensor on CPU for MAEST's internal STFT processing
    wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)

    with torch.no_grad():
        # Get MAEST embeddings from the configured intermediate layer
        _, embeddings = maest_model(wav_tensor, transformer_block=config.maest_layer)

        # Move embeddings to the classifier device
        embeddings = embeddings.to(config.device)

        # Pass embeddings through the multi-head classifier
        logits = model(embeddings)

    genre_probs = torch.softmax(logits['genre'], dim=1).squeeze(0).cpu().numpy()
    instrument_probs = torch.softmax(logits['instrument'], dim=1).squeeze(0).cpu().numpy()
    mood_probs = torch.softmax(logits['mood'], dim=1).squeeze(0).cpu().numpy()

    genre_top = np.argsort(genre_probs)[::-1][:top_k]
    instrument_top = np.argsort(instrument_probs)[::-1][:top_k]
    mood_top = np.argsort(mood_probs)[::-1][:top_k]

    return {
        'genre': [(tags['genre'][i], float(genre_probs[i])) for i in genre_top],
        'instrument': [(tags['instrument'][i], float(instrument_probs[i])) for i in instrument_top],
        'mood': [(tags['mood'][i], float(mood_probs[i])) for i in mood_top]
    }
def tag_audio_combined(audio_path: str, top_k: int = 5, normalization: str = "individual"):
    """
    Tag audio using BOTH the Zero-Shot (CLAP + MuLan) AND Multi-Head (MAEST) methods.
    Runs both methods and combines the results for comprehensive tagging.

    Args:
        audio_path: Path to audio file
        top_k: Number of top tags to return per category
        normalization: Normalization method for zero-shot scoring

    Returns:
        Dict with 'zero_shot' and 'multihead' keys containing the respective
        predictions, plus 'zero_shot_error' / 'multihead_error' messages when a
        method fails
    """
    results = {
        'zero_shot': None,
        'multihead': None,
        'zero_shot_error': None,
        'multihead_error': None,
    }

    # Run Zero-Shot (CLAP + MuLan)
    try:
        print(" Running Zero-Shot (CLAP + MuLan) tagging...")
        zero_shot_predictions = tag_audio_zero_shot(
            audio_path,
            top_k=top_k,
            normalization=normalization
        )
        results['zero_shot'] = zero_shot_predictions
        print(f" Got {len(zero_shot_predictions)} zero-shot tags")
    except Exception as e:
        print(f" Zero-Shot error: {e}")
        results['zero_shot_error'] = str(e)

    # Run Multi-Head (MAEST)
    try:
        print(" Running Multi-Head (MAEST) tagging...")
        multihead_predictions = tag_audio_multihead(audio_path, top_k=top_k)
        results['multihead'] = multihead_predictions
        total_multihead = sum(len(v) for v in multihead_predictions.values())
        print(f" Got {total_multihead} multi-head tags")
    except Exception as e:
        print(f" Multi-Head error: {e}")
        results['multihead_error'] = str(e)

    return results
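# The three tagging functions above can also be used programmatically outside the
# Gradio UI. A minimal sketch (hypothetical file path):
#   preds = tag_audio_zero_shot("song.mp3", top_k=10)   # [(tag, score), ...]
#   multi = tag_audio_multihead("song.mp3", top_k=5)    # {'genre': [...], 'instrument': [...], 'mood': [...]}
#   both = tag_audio_combined("song.mp3", top_k=5)      # {'zero_shot': ..., 'multihead': ...}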
def separate_vocals(audio_path: str):
    """
    Separate vocals from audio using Demucs with librosa loading.

    Args:
        audio_path: Path to audio file

    Returns:
        Tuple of (vocal_audio_np, sample_rate, all_sources_dict)
    """
    from demucs.apply import apply_model

    model = model_cache.get_demucs_model()

    print(f" Loading audio with librosa at {model.samplerate}Hz (max {config.max_lyrics_duration}s)...")
    wav, sr = librosa.load(audio_path, sr=model.samplerate, mono=False, duration=config.max_lyrics_duration)

    # Ensure stereo (Demucs expects 2 channels)
    if wav.ndim == 1:
        wav = np.stack([wav, wav])
    elif wav.shape[0] == 1:
        wav = np.repeat(wav, 2, axis=0)
    elif wav.shape[0] > 2:
        wav = wav[:2]

    wav_tensor = torch.from_numpy(wav).float().to(config.device)

    print(" Separating sources with Demucs...")
    sources = apply_model(model, wav_tensor.unsqueeze(0), device=config.device)[0]

    source_names = model.sources
    sources_dict = {}
    for i, name in enumerate(source_names):
        sources_dict[name] = sources[i].cpu().numpy()
    print(f" Separated into: {', '.join(source_names)}")

    vocals = sources_dict['vocals']
    return vocals, model.samplerate, sources_dict
def has_vocals(audio_path: str = None, vocal_audio: np.ndarray = None,
               threshold: float = None) -> tuple:
    """
    Detect whether audio contains vocals by checking the energy of the vocal stem.

    Returns:
        Tuple of (has_vocal, vocal_energy, vocals, sample_rate)
    """
    if threshold is None:
        threshold = config.vocal_energy_threshold

    if vocal_audio is None:
        if audio_path is None:
            raise ValueError("Must provide either audio_path or vocal_audio")
        vocals, sr, _ = separate_vocals(audio_path)
    else:
        vocals = vocal_audio
        sr = config.maest_sample_rate  # Nominal; the RMS check below is independent of sample rate

    # RMS energy of the separated vocal stem, compared against the configured threshold
    vocal_energy = np.sqrt(np.mean(vocals ** 2))
    has_vocal = vocal_energy > threshold
    print(f" Vocal energy: {vocal_energy:.4f} (threshold: {threshold:.4f})")
    print(f" Vocals detected: {'Yes' if has_vocal else 'No'}")
    return has_vocal, vocal_energy, vocals, sr


def transcribe_lyrics(audio_path: str) -> dict:
    """
    Transcribe lyrics from audio using Whisper.
    """
    model = model_cache.get_whisper_model()
    print(" Transcribing lyrics with Whisper...")
    result = model.transcribe(audio_path)
    print(f" Detected language: {result.get('language', 'unknown')}")
    print(f" Transcribed {len(result.get('segments', []))} segments")
    return result
def analyze_audio_with_lyrics(audio_path: str, method: str, normalization: str,
                              top_k: int, include_lyrics: bool = True, progress=gr.Progress()):
    """
    Complete pipeline: tag audio, detect vocals, transcribe lyrics.
    """
    if audio_path is None:
        return "Please upload an audio file.", "", "", False

    print("\n" + "="*60)
    print("STARTING ANALYSIS")
    print("="*60)

    try:
        if include_lyrics:
            total_steps = 4
        else:
            total_steps = 2
        current_step = 0

        # Step 1: Tag the audio
        current_step += 1
        progress(current_step / total_steps, desc="Tagging Audio...")
        tags_result, tags_str = analyze_audio(audio_path, method, normalization, top_k)

        lyrics_text = ""
        vocal_detected = False

        # Step 2: Check for vocals and transcribe if requested
        if include_lyrics:
            try:
                current_step += 1
                progress(current_step / total_steps, desc="Separating vocals using Demucs...")
                vocals, sr, _ = separate_vocals(audio_path)
                has_vocal, energy, vocals, sr = has_vocals(vocal_audio=vocals, threshold=config.vocal_energy_threshold)
                vocal_detected = has_vocal

                if has_vocal:
                    current_step += 1
                    progress(current_step / total_steps, desc="Transcribing lyrics...")
                    result = transcribe_lyrics(audio_path=audio_path)
                    lyrics_text = result['text'].strip()
                    if lyrics_text:
                        print(f"\n Lyrics transcribed: {len(lyrics_text)} characters")
                    else:
                        print("\n No lyrics detected (vocals may be humming/instrumental)")
                        lyrics_text = "[Instrumental/No clear lyrics detected]"
                else:
                    print("\n No vocals detected - skipping transcription")
                    lyrics_text = "[Instrumental - No vocals detected]"
            except Exception as e:
                print(f"\n Error during vocal processing: {str(e)}")
                lyrics_text = f"[Error during lyric transcription: {str(e)}]"
                vocal_detected = False
        else:
            print("\n[Skipping lyric transcription]")

        progress(1.0, desc="Analysis complete!")
        print("\n" + "="*60)
        print("ANALYSIS COMPLETE")
        print("="*60 + "\n")
        return tags_result, tags_str, lyrics_text, vocal_detected

    except Exception as e:
        import traceback
        error_msg = f"**Error during analysis:**\n\n{str(e)}\n\n{traceback.format_exc()}"
        progress(1.0, desc="Error occurred.")
        return error_msg, "", "", False
# ============================================================
# Review Generation (Fine-tuned Llama)
# ============================================================

def generate_review_llama(tags_str: str, lyrics_str: str = "", score: float = 7.0,
                          artist: str = "Unknown Artist", title: str = "Unknown Album") -> str:
    """
    Generate a Pitchfork-style review using the fine-tuned Llama model via transformers.

    Args:
        tags_str: JSON string with structured tag data (genres, instruments, moods, qualitative)
        lyrics_str: Transcribed lyrics (optional)
        score: User-provided score (1-10)
        artist: Artist name (optional, can use "Unknown Artist")
        title: Album title (optional, can use "Unknown Album")

    Returns:
        Generated review text
    """
    import json

    try:
        model, tokenizer = model_cache.get_llama_model()
    except ImportError as e:
        return (
            f"**Error:** {str(e)}\n\n"
            "Install with: `pip install transformers accelerate`"
        )

    # Parse structured tags
    genres = []
    instruments = []
    moods = []
    qualitative = []
    if tags_str:
        try:
            tags_data = json.loads(tags_str)
            genres = tags_data.get('genres', [])
            instruments = tags_data.get('instruments', [])
            moods = tags_data.get('moods', [])
            qualitative = tags_data.get('qualitative', [])
        except json.JSONDecodeError:
            # Fallback: treat as a comma-separated string (legacy format)
            qualitative = [t.strip() for t in tags_str.split(",")]

    # Get the top 2 genres for the main prompt
    top_genres = genres[:2] if genres else ["rock"]
    genre_str = " / ".join(top_genres)

    # Build the prompt with explicit structured context
    prompt = (
        f"Write a Pitchfork-style review for '{title}', a {genre_str}-genre song by {artist}. "
        f"Score: {score}/10."
    )

    # Add structured musical characteristics
    context_parts = []
    if instruments:
        instrument_str = ", ".join(instruments[:5])  # Top 5 instruments
        context_parts.append(f"Comment on the following instrumentation: {instrument_str}")
    if moods:
        mood_str = ", ".join(moods[:5])  # Top 5 moods
        context_parts.append(f"Note the Mood/Theme of the song: {mood_str}")
    if qualitative:
        qual_str = ", ".join(qualitative[:10])  # Top 10 qualitative tags
        context_parts.append(f"And be sure to incorporate the following contextual and qualitative characteristics: {qual_str}")
    if context_parts:
        prompt += "\n\n" + "\n".join(context_parts)

    # Add lyrics if available
    if lyrics_str and not lyrics_str.startswith("["):
        prompt += f"\n\nLyrics to incorporate in your analysis:\n{lyrics_str}"

    print("\nGenerating review with transformers...")
    print(f" Model: {config.llama_hf_repo}")
    print(f" Score: {score}/10")
    print(f" Genre: {genre_str}")
    print(f" Instruments: {instruments[:3] if instruments else 'N/A'}")
    print(f" Moods: {moods[:3] if moods else 'N/A'}")

    try:
        # Format as a chat message for the Llama instruct model
        messages = [{"role": "user", "content": prompt}]
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=config.llama_max_new_tokens,
            temperature=0.6,
            top_p=0.85,
            top_k=50,  # Top-k filtering
            repetition_penalty=1.15,  # Reduce repetition
            no_repeat_ngram_size=3,  # Prevent 3-gram repetition
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode, then keep only the newly generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove the input prompt from the output
        if "assistant" in generated_text.lower():
            generated_text = generated_text.split("assistant")[-1].strip()

        # Format output with score
        output = f"**Score: {score}/10** | Genre: {genre_str}\n\n{generated_text}"
        print(f" Generated {len(generated_text)} characters")
        return output
    except Exception as e:
        import traceback
        return f"**Error generating review:** {str(e)}\n\n{traceback.format_exc()}"
# ============================================================
# Gradio Interface Functions
# ============================================================

def format_predictions(predictions: list, method: str) -> str:
    """Format single-head predictions as a readable string."""
    if not predictions:
        return "No predictions available."

    lines = [f"## {method} Results\n"]
    for i, (tag, score) in enumerate(predictions, 1):
        bar_length = int(score * 30)
        bar = "█" * bar_length + "░" * (30 - bar_length)
        lines.append(f"{i:2d}. **{tag}** - {score*100:.1f}% `{bar}`")
    return "\n".join(lines)


def format_multihead_predictions(predictions: dict) -> str:
    """Format multi-head predictions as a readable string."""
    if not predictions:
        return "No predictions available."

    lines = ["## Multi-Head Tagging Results\n"]

    lines.append("### Genre")
    for i, (tag, score) in enumerate(predictions['genre'], 1):
        bar_length = int(score * 30)
        bar = "█" * bar_length + "░" * (30 - bar_length)
        lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
    lines.append("")

    lines.append("### Instrument")
    for i, (tag, score) in enumerate(predictions['instrument'], 1):
        bar_length = int(score * 30)
        bar = "█" * bar_length + "░" * (30 - bar_length)
        lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
    lines.append("")

    lines.append("### Mood / Theme")
    for i, (tag, score) in enumerate(predictions['mood'], 1):
        bar_length = int(score * 30)
        bar = "█" * bar_length + "░" * (30 - bar_length)
        lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
    return "\n".join(lines)


def format_combined_predictions(results: dict) -> str:
    """Format combined predictions from both the Zero-Shot and Multi-Head methods."""
    lines = ["## Combined Tagging Results\n"]

    # Zero-Shot results
    lines.append("---")
    lines.append("### Zero-Shot Tags (CLAP + MuLan)")
    lines.append("*Qualitative descriptors from open-vocabulary models*\n")
    if results.get('zero_shot'):
        for i, (tag, score) in enumerate(results['zero_shot'], 1):
            bar_length = int(score * 30)
            bar = "█" * bar_length + "░" * (30 - bar_length)
            lines.append(f"{i:2d}. **{tag}** - {score*100:.1f}% `{bar}`")
    elif results.get('zero_shot_error'):
        lines.append(f"*Error: {results['zero_shot_error']}*")
    else:
        lines.append("*Not available*")

    # Multi-Head results
    lines.append("")
    lines.append("---")
    lines.append("### Multi-Head Tags (MAEST)")
    lines.append("*Structured predictions for genre, instrument, and mood*\n")
    if results.get('multihead'):
        multihead = results['multihead']

        lines.append("#### Genre")
        for i, (tag, score) in enumerate(multihead['genre'], 1):
            bar_length = int(score * 30)
            bar = "█" * bar_length + "░" * (30 - bar_length)
            lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
        lines.append("")

        lines.append("#### Instrument")
        for i, (tag, score) in enumerate(multihead['instrument'], 1):
            bar_length = int(score * 30)
            bar = "█" * bar_length + "░" * (30 - bar_length)
            lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
        lines.append("")

        lines.append("#### Mood / Theme")
        for i, (tag, score) in enumerate(multihead['mood'], 1):
            bar_length = int(score * 30)
            bar = "█" * bar_length + "░" * (30 - bar_length)
            lines.append(f"{i}. **{tag}** - {score*100:.1f}% `{bar}`")
    elif results.get('multihead_error'):
        lines.append(f"*Error: {results['multihead_error']}*")
    else:
        lines.append("*Not available*")

    return "\n".join(lines)
def _extract_combined_tag_string(results: dict) -> str:
    """
    Extract structured tag data as a JSON string from combined predictions.
    Returns JSON with separate fields for genres, instruments, moods, and qualitative tags.
    """
    import json

    structured = {
        'genres': [],
        'instruments': [],
        'moods': [],
        'qualitative': [],  # From zero-shot CLAP/MuLan
    }

    # Add multi-head tags (structured)
    if results.get('multihead'):
        structured['genres'] = [tag for tag, _ in results['multihead'].get('genre', [])]
        structured['instruments'] = [tag for tag, _ in results['multihead'].get('instrument', [])]
        structured['moods'] = [tag for tag, _ in results['multihead'].get('mood', [])]

    # Add zero-shot tags (qualitative descriptors)
    if results.get('zero_shot'):
        structured['qualitative'] = [tag for tag, _ in results['zero_shot']]

    return json.dumps(structured)


def _extract_tag_string(predictions, is_multihead=False):
    """Extract structured tag data as a JSON string from predictions."""
    import json

    structured = {
        'genres': [],
        'instruments': [],
        'moods': [],
        'qualitative': [],
    }
    if is_multihead:
        structured['genres'] = [tag for tag, _ in predictions.get('genre', [])]
        structured['instruments'] = [tag for tag, _ in predictions.get('instrument', [])]
        structured['moods'] = [tag for tag, _ in predictions.get('mood', [])]
    else:
        # Zero-shot tags go to qualitative
        structured['qualitative'] = [tag for tag, _ in predictions]
    return json.dumps(structured)
def analyze_audio(audio_file, method: str, normalization: str, top_k: int):
    """
    Main analysis function called by the Gradio interface.
    """
    if audio_file is None:
        return "Please upload an audio file.", ""

    try:
        norm_map = {
            "Individual (recommended)": "individual",
            "Global": "global",
            "MuLan Only": "mulan_only"
        }
        norm_value = norm_map.get(normalization, "individual")

        if method == "Combined (Both Models)":
            # Run both models and combine results
            results = tag_audio_combined(
                audio_file,
                top_k=int(top_k),
                normalization=norm_value
            )
            tags_str = _extract_combined_tag_string(results)
            return format_combined_predictions(results), tags_str
        elif method == "Zero-Shot (CLAP + MuLan)":
            predictions = tag_audio_zero_shot(
                audio_file,
                top_k=int(top_k),
                normalization=norm_value
            )
            tags_str = _extract_tag_string(predictions)
            return format_predictions(predictions, "Zero-Shot Tagging"), tags_str
        elif method == "MTG-Jamendo Multi-Head (MAEST)":
            predictions = tag_audio_multihead(audio_file, top_k=int(top_k))
            tags_str = _extract_tag_string(predictions, is_multihead=True)
            return format_multihead_predictions(predictions), tags_str
        else:
            return f"Unknown method: {method}", ""
    except FileNotFoundError as e:
        return f"**Error: Missing Required Files**\n\n{str(e)}", ""
    except ImportError as e:
        return f"**Error: Missing Dependencies**\n\n{str(e)}", ""
    except Exception as e:
        return f"**Error during analysis:**\n\n{str(e)}", ""
def check_available_methods():
    """Check which methods are available based on existing files."""
    available = []
    messages = []

    # Check Zero-Shot files
    clap_exists = (SCRIPT_DIR / config.clap_embeddings_path).exists()
    mulan_exists = (SCRIPT_DIR / config.mulan_embeddings_path).exists()
    tags_exists = (SCRIPT_DIR / config.tag_names_path).exists()
    zero_shot_available = clap_exists and mulan_exists and tags_exists
    if zero_shot_available:
        available.append("Zero-Shot (CLAP + MuLan)")
    else:
        missing = []
        if not clap_exists:
            missing.append("clap_tag_embeddings.npy")
        if not mulan_exists:
            missing.append("mulan_tag_embeddings.npy")
        if not tags_exists:
            missing.append("musiccaps_tag_names.txt")
        messages.append(f"Zero-Shot: Missing {', '.join(missing)}")

    # Check the MTG-Jamendo Multi-Head MAEST checkpoint
    multihead_available = (SCRIPT_DIR / config.multihead_checkpoint_path).exists()
    if multihead_available:
        available.append("MTG-Jamendo Multi-Head (MAEST)")
    else:
        messages.append(f"Multi-Head MAEST: Missing {config.multihead_checkpoint_path}")

    # Add the Combined option if both methods are available
    if zero_shot_available and multihead_available:
        available.insert(0, "Combined (Both Models)")  # Insert at beginning as default

    # The Llama model is loaded from HuggingFace Hub via transformers
    messages.append(f"Llama Review Model: {config.llama_hf_repo} (via transformers)")

    return available, messages
def create_interface():
    """Create and configure the Gradio interface."""
    available_methods, status_messages = check_available_methods()

    if available_methods:
        status = f"**Available methods:** {', '.join(available_methods)}"
    else:
        status = "**Warning:** No tagging methods available!"
    if status_messages:
        status += "\n\n**Status:**\n" + "\n".join(f"- {m}" for m in status_messages)

    all_methods = [
        "Combined (Both Models)",
        "Zero-Shot (CLAP + MuLan)",
        "MTG-Jamendo Multi-Head (MAEST)"
    ]

    # Default to Combined if available, otherwise the first available method
    if "Combined (Both Models)" in available_methods:
        default_method = "Combined (Both Models)"
    elif available_methods:
        default_method = available_methods[0]
    else:
        default_method = all_methods[0]

    # The Llama review model is loaded from HuggingFace Hub via transformers
    llama_available = True

    with gr.Blocks(title="Stitchfork") as interface:
        gr.Markdown("""
# Stitchfork: AI Music Reviews

Upload a song, get AI-generated tags and a Pitchfork-style review.

**Pipeline:**
1. **Music Tagging:** Combines CLAP/MuLan (qualitative) + MAEST (genre/instrument/mood)
2. **Lyric Transcription:** Demucs separates vocals, Whisper transcribes
3. **Review Generation:** Fine-tuned Llama 3.2 3B generates the review

*Set your desired score and the AI will write a review matching that rating!*
""")
        gr.Markdown(status)

        # Hidden states
        tags_state = gr.State(value="")
        lyrics_state = gr.State(value="")

        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
                method_dropdown = gr.Dropdown(
                    choices=all_methods,
                    value=default_method,
                    label="Tagging Method",
                    info="Combined runs both models for comprehensive tags"
                )
                normalization_dropdown = gr.Dropdown(
                    choices=["Individual (recommended)", "Global", "MuLan Only"],
                    value="Individual (recommended)",
                    label="Normalization (for Zero-Shot tags)",
                    info="How to combine CLAP and MuLan scores",
                    visible=(default_method in ["Zero-Shot (CLAP + MuLan)", "Combined (Both Models)"])
                )
                top_k_slider = gr.Slider(
                    minimum=3,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Tags per Category",
                    info="How many top tags to show"
                )
                include_lyrics_checkbox = gr.Checkbox(
                    value=True,
                    label="Transcribe Lyrics",
                    info="Use Demucs + Whisper to transcribe vocals"
                )
                analyze_btn = gr.Button("Analyze Song", variant="primary", size="lg")
            with gr.Column(scale=2):
                output_text = gr.Markdown(
                    label="Tags",
                    value="Upload an audio file and click 'Analyze' to see predictions."
                )
                lyrics_output = gr.Markdown(
                    label="Transcribed Lyrics",
                    value=""
                )

        # Review Generation Section
        gr.Markdown("---")
        gr.Markdown("### Generate Review")
        with gr.Row():
            with gr.Column(scale=1):
                artist_input = gr.Textbox(
                    label="Artist Name",
                    placeholder="e.g., Radiohead",
                    value="Unknown Artist"
                )
                title_input = gr.Textbox(
                    label="Album/Song Title",
                    placeholder="e.g., OK Computer",
                    value="Unknown Album"
                )
                score_slider = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    value=7.0,
                    step=0.1,
                    label="Review Score",
                    info="Set the score for the generated review (affects tone)"
                )
                review_btn = gr.Button(
                    "Generate Review",
                    variant="secondary",
                    interactive=llama_available
                )
            with gr.Column(scale=2):
                review_output = gr.Markdown(
                    value="Analyze a song first, then generate a review." if llama_available
                    else "**Review generation unavailable:** Llama model not found."
                )

        # Event handlers
        def update_normalization_visibility(method):
            return gr.update(visible=(method in ["Zero-Shot (CLAP + MuLan)", "Combined (Both Models)"]))

        method_dropdown.change(
            fn=update_normalization_visibility,
            inputs=[method_dropdown],
            outputs=[normalization_dropdown]
        )

        def analyze_with_lyrics_wrapper(audio, method, norm, topk, include_lyrics):
            tags_result, tags_str, lyrics_text, vocal_detected = analyze_audio_with_lyrics(
                audio, method, norm, topk, include_lyrics
            )
            if lyrics_text:
                if lyrics_text.startswith("["):
                    lyrics_display = f"**Lyrics:** {lyrics_text}"
                else:
                    lyrics_display = f"## Transcribed Lyrics\n\n{lyrics_text}"
            else:
                lyrics_display = ""
            return tags_result, tags_str, lyrics_text, lyrics_display

        analyze_btn.click(
            fn=analyze_with_lyrics_wrapper,
            inputs=[audio_input, method_dropdown, normalization_dropdown,
                    top_k_slider, include_lyrics_checkbox],
            outputs=[output_text, tags_state, lyrics_state, lyrics_output],
            show_progress="full"
        )

        def generate_review_wrapper(tags_str, lyrics_str, score, artist, title):
            if not tags_str:
                return "Analyze a song first to generate tags, then click Generate Review."
            return generate_review_llama(tags_str, lyrics_str, score, artist, title)

        review_btn.click(
            fn=generate_review_wrapper,
            inputs=[tags_state, lyrics_state, score_slider, artist_input, title_input],
            outputs=[review_output],
            show_progress="full"
        )

        gr.Markdown("""
---
**About:**
This app uses a fine-tuned Llama 3.2 3B model trained on 18,000 Pitchfork reviews
to generate music criticism. The score you set influences the tone - low scores
produce critical reviews, high scores produce enthusiastic praise.
""")

    return interface
# ============================================================
# Main Entry Point
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("Stitchfork: AI Music Reviews")
    print("=" * 60)
    print(f"Device: {config.device}")
    print(f"Script directory: {SCRIPT_DIR}")

    available, messages = check_available_methods()
    print(f"\nAvailable methods: {available if available else 'None'}")
    if messages:
        print("Status:")
        for msg in messages:
            print(f" - {msg}")

    print("\nLaunching Gradio interface...")
    interface = create_interface()
    interface.launch(
        share=False,
        server_name="0.0.0.0",
    )