"""Automatic audio annotation for bulk dataset creation. Two tiers: - basic: librosa-only DSP (tempo, key). No downloads. CPU. ~instant per file. - rich: basic + LAION-CLAP zero-shot tagging (genre, mood, instrument). Lazy-loaded; downloads ~2.35 GB checkpoint on first use. """ from __future__ import annotations import json import logging import os import threading from pathlib import Path from typing import Any, Callable, Dict, List, Optional logger = logging.getLogger(__name__) AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".m4a", ".ogg", ".aac") CLAP_CKPT_FILENAME = "music_audioset_epoch_15_esc_90.14.pt" CLAP_REPO = "lukewys/laion_clap" KEY_NAMES_SHARP = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] KEY_NAMES_FLAT = ["C", "Db", "D", "Eb", "E", "F", "Gb", "G", "Ab", "A", "Bb", "B"] # Krumhansl-Schmuckler key profiles. KRUMHANSL_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88] KRUMHANSL_MINOR = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17] def _iter_audio_files(folder: Path) -> List[Path]: results: List[Path] = [] for root, _, files in os.walk(folder): for name in files: if name.startswith("."): continue if name.lower().endswith(AUDIO_EXTENSIONS): results.append(Path(root) / name) results.sort() return results def _estimate_tempo(y, sr) -> Optional[int]: import librosa try: tempo, _ = librosa.beat.beat_track(y=y, sr=sr) bpm = float(tempo if hasattr(tempo, "__float__") else tempo[0]) if bpm <= 0: return None return int(round(bpm)) except Exception as exc: logger.debug("tempo estimation failed: %s", exc) return None def _estimate_brightness(y, sr) -> Optional[str]: import librosa try: centroid = float(librosa.feature.spectral_centroid(y=y, sr=sr).mean()) except Exception as exc: logger.debug("centroid estimation failed: %s", exc) return None if centroid <= 0: return None if centroid < 1500: return "dark" if centroid > 3500: return "bright" return None def _estimate_character(y, sr) -> Optional[str]: import librosa import numpy as np try: harm, perc = librosa.effects.hpss(y) eh = float(np.mean(harm ** 2)) ep = float(np.mean(perc ** 2)) except Exception as exc: logger.debug("HPSS failed: %s", exc) return None total = eh + ep if total <= 0: return None perc_ratio = ep / total if perc_ratio > 0.65: return "percussion-driven" if perc_ratio < 0.20: return "melodic" return None def _estimate_key(y, sr) -> Optional[str]: import librosa import numpy as np try: chroma = librosa.feature.chroma_cqt(y=y, sr=sr) chroma_mean = chroma.mean(axis=1) if chroma_mean.sum() <= 0: return None chroma_mean = chroma_mean / chroma_mean.sum() major = np.asarray(KRUMHANSL_MAJOR) minor = np.asarray(KRUMHANSL_MINOR) best_score = -1.0 best_key = None for i in range(12): maj_score = float(np.corrcoef(chroma_mean, np.roll(major, i))[0, 1]) min_score = float(np.corrcoef(chroma_mean, np.roll(minor, i))[0, 1]) if maj_score > best_score: best_score = maj_score best_key = f"{KEY_NAMES_SHARP[i]} major" if min_score > best_score: best_score = min_score best_key = f"{KEY_NAMES_SHARP[i]} minor" return best_key except Exception as exc: logger.debug("key estimation failed: %s", exc) return None def _compose_prompt(parts: Dict[str, Any]) -> str: genre = parts.get("genre") mood = parts.get("mood") instruments = parts.get("instruments") or [] bpm = parts.get("bpm") key = parts.get("key") brightness = parts.get("brightness") character = parts.get("character") head_bits: List[str] = [] if mood: head_bits.append(str(mood)) if genre: head_bits.append(f"{genre} track") elif head_bits: head_bits[-1] = f"{head_bits[-1]} track" opening = " ".join(head_bits) descriptors = [d for d in (brightness, character) if d] fragments: List[str] = [] if opening: fragments.append(opening) if descriptors: fragments.append(", ".join(descriptors)) if bpm: fragments.append(f"{bpm} BPM") if key: fragments.append(f"in {key}") if instruments: fragments.append("with " + ", ".join(instruments)) out = ", ".join(fragments) return out[:1].upper() + out[1:] if out else "" class _ClapTagger: """Lazy holder for a LAION-CLAP model used for zero-shot tagging.""" def __init__(self, ckpt_path: Path): self.ckpt_path = ckpt_path self._model = None self._lock = threading.Lock() self._label_embeds: Dict[str, Any] = {} def ensure_loaded(self): if self._model is not None: return with self._lock: if self._model is not None: return if not self.ckpt_path.exists(): raise FileNotFoundError( f"CLAP checkpoint not found at {self.ckpt_path}. " "Download it first via /api/bulk-annotate/download-clap." ) import laion_clap import torch logging.getLogger("transformers").setLevel(logging.ERROR) device = "cuda" if torch.cuda.is_available() else "cpu" model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device) # torch >= 2.6 flipped torch.load(weights_only=True) and newer # transformers dropped the roberta position_ids buffer, so # laion_clap's own load_ckpt errors twice: unpickling, then strict # state_dict mismatch. Replicate its logic safely here. from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict orig_load = torch.load def _trusted_load(*args, **kwargs): kwargs.setdefault("weights_only", False) return orig_load(*args, **kwargs) torch.load = _trusted_load try: state = clap_load_state_dict(str(self.ckpt_path), skip_params=True) finally: torch.load = orig_load missing, unexpected = model.model.load_state_dict(state, strict=False) if unexpected: logger.debug("CLAP unexpected keys ignored: %s", unexpected[:5]) if missing: logger.debug("CLAP missing keys: %s", missing[:5]) self._model = model self._device = device logger.info("CLAP loaded on %s from %s", device, self.ckpt_path) def _embed_labels(self, group: str, prompts: List[str]): import torch key = f"{group}:{'|'.join(prompts)}" if key in self._label_embeds: return self._label_embeds[key] with torch.no_grad(): embed = self._model.get_text_embedding(prompts, use_tensor=True) embed = embed / embed.norm(dim=-1, keepdim=True).clamp_min(1e-8) self._label_embeds[key] = embed return embed def tag(self, audio_path: Path, label_sets: Dict[str, List[str]], top_k_instruments: int = 2) -> Dict[str, Any]: self.ensure_loaded() import torch with torch.no_grad(): audio_embed = self._model.get_audio_embedding_from_filelist( x=[str(audio_path)], use_tensor=True ) audio_embed = audio_embed / audio_embed.norm(dim=-1, keepdim=True).clamp_min(1e-8) out: Dict[str, Any] = {} for group in ("genre", "mood"): labels = label_sets.get(group) or [] if not labels: continue prompts = [f"a {lab} music track" if group == "genre" else f"a {lab} sounding music track" for lab in labels] text_embed = self._embed_labels(group, prompts) sims = (audio_embed @ text_embed.T).squeeze(0) top = int(sims.argmax().item()) out[group] = labels[top] instruments = label_sets.get("instruments") or [] if instruments: prompts = [f"music featuring {lab}" for lab in instruments] text_embed = self._embed_labels("instruments", prompts) sims = (audio_embed @ text_embed.T).squeeze(0) k = min(top_k_instruments, len(instruments)) top_idx = torch.topk(sims, k=k).indices.tolist() out["instruments"] = [instruments[i] for i in top_idx] return out _clap_tagger_singleton: Optional[_ClapTagger] = None _clap_tagger_lock = threading.Lock() def get_clap_tagger(clap_ckpt_path: Path) -> _ClapTagger: global _clap_tagger_singleton with _clap_tagger_lock: if _clap_tagger_singleton is None or _clap_tagger_singleton.ckpt_path != clap_ckpt_path: _clap_tagger_singleton = _ClapTagger(clap_ckpt_path) return _clap_tagger_singleton def clap_checkpoint_path(models_pretrained_dir: Path) -> Path: return models_pretrained_dir / "clap" / CLAP_CKPT_FILENAME def clap_checkpoint_available(models_pretrained_dir: Path) -> bool: return clap_checkpoint_path(models_pretrained_dir).exists() def download_clap_checkpoint( models_pretrained_dir: Path, progress_cb: Optional[Callable[[str], None]] = None, ) -> Path: target = clap_checkpoint_path(models_pretrained_dir) target.parent.mkdir(parents=True, exist_ok=True) if target.exists(): return target from huggingface_hub import hf_hub_download import os if progress_cb: progress_cb("Downloading CLAP checkpoint (~630 MB)…") # Use custom CLAP from fragmenta-models on HF Spaces use_custom_repo = os.getenv('FRAGMENTA_USE_CUSTOM_MODELS', '').lower() == 'true' if use_custom_repo: repo_id = "MazCodes/fragmenta-models" else: repo_id = CLAP_REPO downloaded = hf_hub_download( repo_id=repo_id, filename=CLAP_CKPT_FILENAME, local_dir=str(target.parent), ) downloaded_path = Path(downloaded) if downloaded_path != target: try: downloaded_path.replace(target) except OSError: import shutil shutil.copy2(downloaded_path, target) return target def load_label_sets(label_sets_path: Optional[Path]) -> Dict[str, List[str]]: if not label_sets_path or not label_sets_path.exists(): return {"genre": [], "mood": [], "instruments": []} with open(label_sets_path, "r", encoding="utf-8") as f: data = json.load(f) return { "genre": list(data.get("genre") or []), "mood": list(data.get("mood") or []), "instruments": list(data.get("instruments") or []), } def annotate_file( audio_path: Path, tier: str, clap_tagger: Optional[_ClapTagger], label_sets: Dict[str, List[str]], sr: int = 22050, max_seconds: float = 60.0, ) -> Dict[str, Any]: import librosa parts: Dict[str, Any] = {} try: y, loaded_sr = librosa.load(str(audio_path), sr=sr, mono=True, duration=max_seconds) except Exception as exc: logger.warning("librosa failed to load %s: %s", audio_path.name, exc) return { "file_name": audio_path.name, "prompt": "", "path": str(audio_path), "error": f"load failed: {exc}", } parts["bpm"] = _estimate_tempo(y, loaded_sr) parts["key"] = _estimate_key(y, loaded_sr) parts["brightness"] = _estimate_brightness(y, loaded_sr) parts["character"] = _estimate_character(y, loaded_sr) if tier == "rich" and clap_tagger is not None: try: tags = clap_tagger.tag(audio_path, label_sets) parts.update(tags) except Exception as exc: logger.warning("CLAP tagging failed for %s: %s", audio_path.name, exc) prompt = _compose_prompt(parts) return { "file_name": audio_path.name, "prompt": prompt, "path": str(audio_path), "attributes": parts, } def annotate_folder( folder: Path, tier: str, label_sets: Dict[str, List[str]], clap_ckpt_path: Optional[Path] = None, progress_cb: Optional[Callable[[int, int, str], None]] = None, ) -> List[Dict[str, Any]]: folder = Path(folder) if not folder.exists() or not folder.is_dir(): raise ValueError(f"Folder not found: {folder}") files = _iter_audio_files(folder) if not files: raise ValueError(f"No audio files found in {folder}") clap_tagger: Optional[_ClapTagger] = None if tier == "rich": if not clap_ckpt_path or not Path(clap_ckpt_path).exists(): raise FileNotFoundError( "Rich tier requires the CLAP checkpoint; download it first." ) clap_tagger = get_clap_tagger(Path(clap_ckpt_path)) clap_tagger.ensure_loaded() results: List[Dict[str, Any]] = [] total = len(files) for i, audio_path in enumerate(files, start=1): if progress_cb: progress_cb(i, total, audio_path.name) entry = annotate_file(audio_path, tier, clap_tagger, label_sets) results.append(entry) return results def unload_clap(): """Free CLAP weights from VRAM. Call before training starts.""" global _clap_tagger_singleton with _clap_tagger_lock: if _clap_tagger_singleton is not None: _clap_tagger_singleton._model = None _clap_tagger_singleton._label_embeds = {} _clap_tagger_singleton = None try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass