| """ | |
| AutoMixAI β HuggingFace Space Backend (All-in-One) | |
| Consolidated FastAPI backend hosting ALL services: | |
| - /upload Upload audio files (stored temporarily) | |
| - /analyze BPM + beat detection + energy analysis | |
| - /mix Advanced DJ mixing with EQ crossfade | |
| - /generate Procedural drum beat generation | |
| - /output/{id} Download generated files | |
| - /recognize Song recognition via Shazam API | |
| - /health Health check | |
| Advanced DJ Mixing Features: | |
| - LUFS loudness normalization (-24 LUFS pre-mix, -14 LUFS final) | |
| - High-pass filtering (40Hz rumble removal) | |
| - Beat detection + BPM estimation | |
| - Pitch-preserving time-stretching to common BPM | |
| - Beat-aligned trimming | |
| - EQ-based crossfade (DJ-style bass swap transition) | |
| - Equal-power S-curve crossfade | |
| - Per-track EQ: bass boost, brightness, vocal boost | |
| - Stereo panning | |
| - Final LUFS mastering to streaming standard | |
| """ | |
import os
import re
import subprocess
import tempfile
import traceback
import uuid
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import librosa
import soundfile as sf
import scipy.signal
import pyloudnorm as pyln
from transformers import pipeline as hf_pipeline
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import httpx
# ──────────────────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────────────────
SR = 22050               # librosa default, used for analysis
SR_MIX = 44100           # output sample rate for mixing
HOP_LENGTH = 512
TARGET_LOUDNESS = -24.0  # pre-mix normalization (LUFS)
FINAL_LOUDNESS = -14.0   # streaming-standard final loudness (LUFS)

UPLOAD_DIR = Path(tempfile.gettempdir()) / "automixai_uploads"
OUTPUT_DIR = Path(tempfile.gettempdir()) / "automixai_outputs"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

RAPIDAPI_KEY = os.environ.get("RAPIDAPI_KEY", "")
SHAZAM_URL = "https://shazam-core.p.rapidapi.com/v1/tracks/recognize"
SHAZAM_HOST = "shazam-core.p.rapidapi.com"
# ──────────────────────────────────────────────────────────────────────────────
# FASTAPI APP
# ──────────────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="AutoMixAI Backend",
    description="AI-powered DJ mixing, beat generation, audio analysis, and song recognition.",
    version="2.0.0",
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────────────
# SCHEMAS
# ──────────────────────────────────────────────────────────────────────────────
class UploadResponse(BaseModel):
    file_id: str
    filename: str
    duration: float
    message: str = "File uploaded successfully."


class AnalyzeRequest(BaseModel):
    file_id: str


class AnalysisResponse(BaseModel):
    file_id: str
    bpm: float
    beat_times: list[float]
    duration: float
    sample_rate: int
    genre: str = "unknown"
    genre_confidence: float = 0.0
    genre_top3: list = []
    tags: list[str] = []
    tag_scores: list = []
    mood: str = "neutral"
    has_vocals: bool = False
    dominant_instrument: str = "unknown"
    instrument_confidence: float = 0.0
    instruments_top3: list = []
    energy: str = "medium"
    message: str = "Analysis complete."


class MixRequest(BaseModel):
    file_id_a: str
    file_id_b: str
    crossfade_duration: float = 8.0
    bass_boost: float = 0.0
    brightness: float = 0.0
    vocal_boost: float = 0.0
    pan_a: float = 0.0
    pan_b: float = 0.0
    eq_transition: bool = True  # use the EQ-based DJ crossfade


class MixResponse(BaseModel):
    output_file_id: str
    duration: float
    bpm_a: float
    bpm_b: float
    target_bpm: float
    message: str = "Mix generated successfully."


class GenerateBeatRequest(BaseModel):
    prompt: str = Field(..., min_length=3, max_length=500)
    bars: int = Field(default=4, ge=1, le=32)


class PatternInfo(BaseModel):
    kick: list[int]
    snare: list[int]
    hihat_c: list[int]
    hihat_o: list[int]
    clap: list[int]


class GenerateBeatResponse(BaseModel):
    output_file_id: str
    genre: str
    bpm: float
    bars: int
    complexity: str
    description: str
    duration: float
    pattern: PatternInfo
    sample_rate: int = 44100
    message: str = "Beat generated successfully."
# ──────────────────────────────────────────────────────────────────────────────
# HELPERS
# ──────────────────────────────────────────────────────────────────────────────
def generate_file_id() -> str:
    return uuid.uuid4().hex


def find_upload(file_id: str) -> Path:
    # Linear scan of the upload dir; fine for a demo Space with few uploads.
    for path in UPLOAD_DIR.iterdir():
        if path.stem == file_id:
            return path
    raise FileNotFoundError(f"No file found for ID '{file_id}'")
# ──────────────────────────────────────────────────────────────────────────────
# AUDIO LOADING
# ──────────────────────────────────────────────────────────────────────────────
def load_audio(path: str, sr: int = None, mono: bool = True):
    """Load audio at the analysis rate and peak-normalize to [-1, 1]."""
    sr = sr or SR
    y, loaded_sr = librosa.load(path, sr=sr, mono=mono)
    peak = np.max(np.abs(y))
    if peak > 0:
        y = y / peak
    return y, loaded_sr


def load_audio_full_rate(path: str, mono: bool = True):
    """Load audio resampled to SR_MIX (44.1 kHz) for mixing quality."""
    y, sr = librosa.load(path, sr=SR_MIX, mono=mono)
    return y, sr


def get_audio_info(path: str) -> dict:
    try:
        info = sf.info(path)
        return {"duration": info.duration, "sample_rate": info.samplerate, "channels": info.channels}
    except Exception:
        # soundfile cannot read every container (e.g. some m4a/aac); fall back to librosa
        y, sr = librosa.load(path, sr=None, mono=False)
        channels = 1 if y.ndim == 1 else y.shape[0]
        duration = (len(y) if y.ndim == 1 else y.shape[1]) / sr
        return {"duration": float(duration), "sample_rate": int(sr), "channels": int(channels)}


def save_audio(path: str, audio: np.ndarray, sr: int):
    sf.write(path, audio, sr, subtype="PCM_16")
# ──────────────────────────────────────────────────────────────────────────────
# BEAT DETECTION + BPM
# ──────────────────────────────────────────────────────────────────────────────
def detect_beats(y: np.ndarray, sr: int) -> list[float]:
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr, hop_length=HOP_LENGTH)
    beat_times = librosa.frames_to_time(beat_frames, sr=sr, hop_length=HOP_LENGTH).tolist()
    return beat_times


def estimate_bpm_from_beats(beat_times: list[float]) -> float:
    """BPM from the median inter-beat interval; robust to a few missed beats."""
    if len(beat_times) < 2:
        return 120.0
    ibis = np.diff(beat_times)
    median_ibi = float(np.median(ibis))
    if median_ibi <= 0:
        return 120.0
    return round(60.0 / median_ibi, 2)


def estimate_bpm_librosa(y: np.ndarray, sr: int) -> float:
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr, hop_length=HOP_LENGTH)
    return round(float(np.asarray(tempo).flat[0]), 2)
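

# Illustrative sanity check (never called in production): a perfectly regular
# click track with 0.5 s between beats should come out at 60 / 0.5 = 120 BPM.
def _example_bpm_from_beats() -> None:
    assert estimate_bpm_from_beats([0.0, 0.5, 1.0, 1.5]) == 120.0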
# ──────────────────────────────────────────────────────────────────────────────
# ADVANCED DJ MIXING ENGINE
# ──────────────────────────────────────────────────────────────────────────────
def normalize_loudness(audio: np.ndarray, sr: int, target_lufs: float) -> np.ndarray:
    """LUFS loudness normalization (EBU R128)."""
    try:
        meter = pyln.Meter(sr)
        current_loudness = meter.integrated_loudness(audio)
        if np.isinf(current_loudness) or np.isnan(current_loudness):
            return audio  # silence or too short to measure
        return pyln.normalize.loudness(audio, current_loudness, target_lufs)
    except Exception:
        return audio
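

# For reference, pyln.normalize.loudness is a pure gain change:
#   gain_db = target_lufs - measured_lufs;  output = input * 10 ** (gain_db / 20)
# e.g. a track measured at -18 LUFS is attenuated by 6 dB to reach -24 LUFS.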
def highpass_filter(audio: np.ndarray, sr: int, cutoff: float = 40.0) -> np.ndarray:
    """Remove sub-bass rumble below the cutoff frequency."""
    sos = scipy.signal.butter(2, cutoff, btype='highpass', fs=sr, output='sos')
    return scipy.signal.sosfilt(sos, audio).astype(np.float32)


def lowpass_filter(audio: np.ndarray, sr: int, cutoff: float = 200.0) -> np.ndarray:
    """Extract bass frequencies (below cutoff)."""
    sos = scipy.signal.butter(2, cutoff, btype='low', fs=sr, output='sos')
    return scipy.signal.sosfilt(sos, audio).astype(np.float32)


def highpass_extract(audio: np.ndarray, sr: int, cutoff: float = 200.0) -> np.ndarray:
    """Extract mid+high frequencies (above cutoff)."""
    sos = scipy.signal.butter(2, cutoff, btype='high', fs=sr, output='sos')
    return scipy.signal.sosfilt(sos, audio).astype(np.float32)


def apply_bass_boost(audio: np.ndarray, sr: int, amount: float) -> np.ndarray:
    """Boost low frequencies (below 150 Hz) by mixing a low-passed copy back in."""
    if amount <= 0:
        return audio
    sos = scipy.signal.butter(2, 150, btype='low', fs=sr, output='sos')
    bass = scipy.signal.sosfilt(sos, audio)
    return (audio + amount * bass).astype(np.float32)


def apply_brightness(audio: np.ndarray, sr: int, amount: float) -> np.ndarray:
    """Boost high frequencies (above 4 kHz) by mixing a high-passed copy back in."""
    if amount <= 0:
        return audio
    sos = scipy.signal.butter(2, 4000, btype='high', fs=sr, output='sos')
    highs = scipy.signal.sosfilt(sos, audio)
    return (audio + amount * highs).astype(np.float32)


def pan_stereo(mono_audio: np.ndarray, pan: float) -> np.ndarray:
    """
    Pan mono audio to stereo; pan runs from -1 (hard left) to 1 (hard right),
    0 = center. Simple cosine law that only attenuates the opposite channel.
    Returns an array of shape (2, n).
    """
    left_gain = np.cos(max(0, pan) * np.pi / 2)
    right_gain = np.cos(max(0, -pan) * np.pi / 2)
    return np.vstack((mono_audio * left_gain, mono_audio * right_gain))
def time_stretch_to_bpm(y: np.ndarray, sr: int, original_bpm: float, target_bpm: float) -> np.ndarray:
    """Pitch-preserving time-stretch to match the target BPM."""
    if original_bpm <= 0:
        return y
    rate = target_bpm / original_bpm
    if abs(rate - 1.0) < 0.005:  # within 0.5%: not worth the stretching artifacts
        return y
    return librosa.effects.time_stretch(y, rate=rate)
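

# Example: stretching a 100 BPM track to a 110 BPM target uses rate = 1.1,
# i.e. ~10% faster playback and a correspondingly shorter signal; pitch is
# preserved because librosa's stretch is phase-vocoder based, not resampling.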
def align_to_beat(y: np.ndarray, sr: int, beat_times: list[float]) -> np.ndarray:
    """Trim audio to start on the first detected beat."""
    if not beat_times:
        return y
    start_sample = int(beat_times[0] * sr)
    return y[start_sample:]
def equal_power_crossfade(y1: np.ndarray, y2: np.ndarray, sr: int, duration: float) -> np.ndarray:
    """
    Equal-power crossfade - psychoacoustically smooth.
    Uses sine/cosine gain curves instead of linear ramps: the squared gains
    always sum to 1, so the summed power stays constant and the mid-fade
    volume dip of a linear crossfade is avoided.
    """
    fade_samples = int(duration * sr)
    fade_samples = min(fade_samples, len(y1), len(y2))
    if fade_samples < 1:
        return np.concatenate([y1, y2])
    # Equal-power gains: cos(t)**2 + sin(t)**2 == 1 at every sample
    t = np.linspace(0, np.pi / 2, fade_samples)
    fade_out = np.cos(t)  # smooth fade-out for track A
    fade_in = np.sin(t)   # smooth fade-in for track B
    tail = y1[-fade_samples:] * fade_out
    head = y2[:fade_samples] * fade_in
    mixed_region = tail + head
    return np.concatenate([y1[:-fade_samples], mixed_region, y2[fade_samples:]])
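

# Numerical sketch of the constant-power property (illustrative, never called):
# the two gain curves satisfy fade_out**2 + fade_in**2 == 1 at every sample.
def _example_equal_power_gains(n: int = 1024) -> None:
    t = np.linspace(0, np.pi / 2, n)
    assert np.allclose(np.cos(t) ** 2 + np.sin(t) ** 2, 1.0)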
def eq_crossfade(y1: np.ndarray, y2: np.ndarray, sr: int, duration: float) -> np.ndarray:
    """
    Professional DJ-style EQ crossfade transition.
    Instead of simply fading volumes, this mimics what real DJs do:
    1. Swap Track A's BASS for Track B's BASS across the transition
    2. Crossfade the mid+high frequencies with equal-power curves
    3. Prevents muddiness and bass clashing during transitions
    This bass-swap technique is a staple of professional DJ sets.
    """
    fade_samples = int(duration * sr)
    fade_samples = min(fade_samples, len(y1), len(y2))
    if fade_samples < 1:
        return np.concatenate([y1, y2])
    # Split both tracks into bass (< 200 Hz) and mid+high (> 200 Hz)
    y1_bass = lowpass_filter(y1[-fade_samples:], sr, cutoff=200)
    y1_mid_high = highpass_extract(y1[-fade_samples:], sr, cutoff=200)
    y2_bass = lowpass_filter(y2[:fade_samples], sr, cutoff=200)
    y2_mid_high = highpass_extract(y2[:fade_samples], sr, cutoff=200)
    # Equal-power gains for the mid+high band
    t = np.linspace(0, np.pi / 2, fade_samples)
    fade_out = np.cos(t)
    fade_in = np.sin(t)
    # Bass: cubed curves cut/rise more sharply around the midpoint, so the two
    # basslines overlap as little as possible (prevents bass clash)
    bass_fade_out = np.cos(t) ** 3
    bass_fade_in = np.sin(t) ** 3
    # Sum the bass swap and the mid/high crossfade
    bass_region = y1_bass * bass_fade_out + y2_bass * bass_fade_in
    mid_high_region = y1_mid_high * fade_out + y2_mid_high * fade_in
    mixed_region = bass_region + mid_high_region
    return np.concatenate([y1[:-fade_samples], mixed_region, y2[fade_samples:]])
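

# Note on the 200 Hz band split: a pair of 2nd-order Butterworth low/high
# filters is not a perfectly complementary crossover (the bands do not sum
# back to the original with flat magnitude and phase). A Linkwitz-Riley
# crossover would be the textbook choice; over a multi-second transition the
# difference is unlikely to be audible, so the simpler split is kept here.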
def create_advanced_mix(
    path_a: str,
    path_b: str,
    output_path: str,
    crossfade_duration: float = 8.0,
    bass_boost: float = 0.0,
    brightness: float = 0.0,
    vocal_boost: float = 0.0,
    pan_a: float = 0.0,
    pan_b: float = 0.0,
    eq_transition: bool = True,
) -> dict:
    """
    Advanced DJ mixing pipeline.
    Steps:
    1. Load both tracks at 44.1 kHz
    2. LUFS-normalize to TARGET_LOUDNESS (-24 LUFS)
    3. High-pass filter at 40 Hz (remove rumble)
    4. Detect beats + estimate BPM
    5. Time-stretch both to a common BPM (the average)
    6. Align to the first beat boundary
    7. Apply per-track EQ (bass boost, brightness, vocal boost)
    8. EQ-based crossfade or equal-power crossfade
    9. Final LUFS mastering to FINAL_LOUDNESS (-14 LUFS)

    Note: pan_a/pan_b are accepted for API compatibility but not applied yet;
    the pipeline is mono end to end (pan_stereo would make it stereo).
    """
    print("=== Creating advanced DJ mix ===")
    print(f"Track A: {path_a}")
    print(f"Track B: {path_b}")
    # 1. Load at mixing quality
    y_a, sr = load_audio_full_rate(path_a)
    y_b, _ = load_audio_full_rate(path_b)
    print(f"Loaded: A={len(y_a)/sr:.1f}s, B={len(y_b)/sr:.1f}s at {sr}Hz")
    # 2. LUFS normalize
    y_a = normalize_loudness(y_a, sr, TARGET_LOUDNESS)
    y_b = normalize_loudness(y_b, sr, TARGET_LOUDNESS)
    print(f"LUFS normalized to {TARGET_LOUDNESS} LUFS")
    # 3. High-pass filter
    y_a = highpass_filter(y_a, sr, cutoff=40)
    y_b = highpass_filter(y_b, sr, cutoff=40)
    # 4. Beat detection + BPM (analyze at the lower rate for speed,
    #    then apply the results to the full-rate audio)
    y_a_analysis = librosa.resample(y_a, orig_sr=sr, target_sr=SR)
    y_b_analysis = librosa.resample(y_b, orig_sr=sr, target_sr=SR)
    beats_a = detect_beats(y_a_analysis, SR)
    beats_b = detect_beats(y_b_analysis, SR)
    bpm_a = estimate_bpm_from_beats(beats_a) if len(beats_a) >= 2 else estimate_bpm_librosa(y_a_analysis, SR)
    bpm_b = estimate_bpm_from_beats(beats_b) if len(beats_b) >= 2 else estimate_bpm_librosa(y_b_analysis, SR)
    target_bpm = round((bpm_a + bpm_b) / 2, 2)
    print(f"BPMs: A={bpm_a:.1f}, B={bpm_b:.1f} -> target={target_bpm:.1f}")
    # 5. Time-stretch to the common BPM
    y_a = time_stretch_to_bpm(y_a, sr, bpm_a, target_bpm)
    y_b = time_stretch_to_bpm(y_b, sr, bpm_b, target_bpm)
    print("Time-stretched to common BPM")
    # Re-detect beats after stretching, for alignment
    y_a_stretched_analysis = librosa.resample(y_a, orig_sr=sr, target_sr=SR)
    y_b_stretched_analysis = librosa.resample(y_b, orig_sr=sr, target_sr=SR)
    beats_a_new = detect_beats(y_a_stretched_analysis, SR)
    beats_b_new = detect_beats(y_b_stretched_analysis, SR)
    # 6. Align to beat boundaries
    y_a = align_to_beat(y_a, sr, beats_a_new)
    y_b = align_to_beat(y_b, sr, beats_b_new)
    # 7. Apply per-track EQ
    y_a = apply_bass_boost(y_a, sr, bass_boost)
    y_b = apply_bass_boost(y_b, sr, bass_boost)
    y_a = apply_brightness(y_a, sr, brightness)
    y_b = apply_brightness(y_b, sr, brightness)
    # Vocal boost: emphasize the 1-4 kHz band where vocals sit
    if vocal_boost > 0:
        sos = scipy.signal.butter(2, [1000, 4000], btype='band', fs=sr, output='sos')
        y_a_vocal = scipy.signal.sosfilt(sos, y_a)
        y_b_vocal = scipy.signal.sosfilt(sos, y_b)
        y_a = y_a + vocal_boost * y_a_vocal
        y_b = y_b + vocal_boost * y_b_vocal
    # 8. Crossfade
    if eq_transition:
        mixed = eq_crossfade(y_a, y_b, sr, duration=crossfade_duration)
        print(f"EQ crossfade applied: {crossfade_duration:.1f}s")
    else:
        mixed = equal_power_crossfade(y_a, y_b, sr, duration=crossfade_duration)
        print(f"Equal-power crossfade applied: {crossfade_duration:.1f}s")
    # 9. Final LUFS mastering
    mixed = normalize_loudness(mixed, sr, FINAL_LOUDNESS)
    print(f"Final mastering to {FINAL_LOUDNESS} LUFS")
    # Safety gain (not a true limiter) to prevent clipping after the loudness gain
    peak = np.max(np.abs(mixed))
    if peak > 0.99:
        mixed = mixed * (0.99 / peak)
    # Save
    save_audio(output_path, mixed, sr)
    duration = float(len(mixed) / sr)
    print(f"Mix complete: {duration:.1f}s saved to {output_path}")
    return {
        "bpm_a": bpm_a,
        "bpm_b": bpm_b,
        "target_bpm": target_bpm,
        "duration": duration,
    }
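

# Offline usage sketch (file names are hypothetical):
#   stats = create_advanced_mix("track_a.wav", "track_b.wav", "mix.wav",
#                               crossfade_duration=8.0, eq_transition=True)
#   # -> {"bpm_a": ..., "bpm_b": ..., "target_bpm": ..., "duration": ...}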
# ──────────────────────────────────────────────────────────────────────────────
# ML-POWERED AUDIO CLASSIFICATION (Genre, Mood, Vocals)
# Uses HuggingFace transformer models for accurate classification
# ──────────────────────────────────────────────────────────────────────────────
# Model IDs
GENRE_MODEL_ID = "dima806/music_genres_classification"      # wav2vec2, GTZAN 10 genres
MOOD_MODEL_ID = "StanislavKo28/music_moods_classification"  # wav2vec2, 14 moods

# Lazy-loaded classifiers (loaded on first use to speed up startup)
_genre_classifier = None
_mood_classifier = None


def _get_genre_classifier():
    """Lazy-load the genre classification pipeline."""
    global _genre_classifier
    if _genre_classifier is None:
        print(f"Loading genre model: {GENRE_MODEL_ID}")
        _genre_classifier = hf_pipeline(
            "audio-classification",
            model=GENRE_MODEL_ID,
            device=-1,  # CPU
        )
        print("Genre model loaded.")
    return _genre_classifier


def _get_mood_classifier():
    """Lazy-load the mood classification pipeline."""
    global _mood_classifier
    if _mood_classifier is None:
        print(f"Loading mood model: {MOOD_MODEL_ID}")
        _mood_classifier = hf_pipeline(
            "audio-classification",
            model=MOOD_MODEL_ID,
            device=-1,  # CPU
        )
        print("Mood model loaded.")
    return _mood_classifier
def classify_genre(file_path: str) -> dict:
    """
    Classify music genre using dima806/music_genres_classification.
    Returns the top genre + confidence + top 3 predictions.
    Genres: blues, classical, country, disco, hiphop, jazz, metal, pop, reggae, rock
    """
    try:
        classifier = _get_genre_classifier()
        results = classifier(file_path, top_k=5)
        top = results[0]
        top3 = [
            {"genre": r["label"], "confidence": round(r["score"], 4)}
            for r in results[:3]
        ]
        return {
            "genre": top["label"],
            "confidence": round(top["score"], 4),
            "top3": top3,
        }
    except Exception as e:
        print(f"Genre classification error: {e}")
        return {"genre": "unknown", "confidence": 0.0, "top3": []}
def classify_mood(file_path: str) -> dict:
    """
    Classify music mood using StanislavKo28/music_moods_classification.
    Returns the top mood + confidence + top 3 predictions.
    Moods: angry, dark, energetic, epic, euphoric, happy, mysterious,
    relaxing, romantic, sad, scary, glamorous, uplifting, sentimental
    """
    try:
        classifier = _get_mood_classifier()
        results = classifier(file_path, top_k=5)
        top = results[0]
        top3 = [
            {"mood": r["label"], "confidence": round(r["score"], 4)}
            for r in results[:3]
        ]
        return {
            "mood": top["label"],
            "confidence": round(top["score"], 4),
            "top3": top3,
        }
    except Exception as e:
        print(f"Mood classification error: {e}")
        return {"mood": "neutral", "confidence": 0.0, "top3": []}
def detect_vocals(y: np.ndarray, sr: int) -> dict:
    """
    Detect whether the audio has vocals using harmonic/percussive separation.
    If the harmonic component has strong mid-frequency energy (1-4 kHz vocal
    range), the track likely has vocals.
    """
    try:
        # Separate harmonic and percussive components
        y_harmonic, y_percussive = librosa.effects.hpss(y)
        # Spectral centroid of the harmonic part
        harmonic_centroid = float(np.mean(librosa.feature.spectral_centroid(y=y_harmonic, sr=sr)))
        # Energy in the vocal frequency range (1-4 kHz)
        S = np.abs(librosa.stft(y_harmonic))
        freqs = librosa.fft_frequencies(sr=sr)
        vocal_mask = (freqs >= 1000) & (freqs <= 4000)
        full_mask = freqs >= 80
        vocal_energy = float(np.mean(S[vocal_mask, :])) if np.any(vocal_mask) else 0
        total_energy = float(np.mean(S[full_mask, :])) if np.any(full_mask) else 1
        vocal_ratio = vocal_energy / max(total_energy, 1e-8)
        # Harmonic-to-percussive energy ratio
        h_energy = float(np.mean(y_harmonic ** 2))
        p_energy = float(np.mean(y_percussive ** 2))
        hp_ratio = h_energy / max(p_energy, 1e-8)
        # Decision: high vocal-band ratio + high harmonic content = vocals
        has_vocals = vocal_ratio > 0.8 and hp_ratio > 1.2
        return {
            "has_vocals": has_vocals,
            "vocal_ratio": round(vocal_ratio, 4),
            "harmonic_percussive_ratio": round(hp_ratio, 4),
            "label": "Vocal" if has_vocals else "Instrumental",
        }
    except Exception as e:
        print(f"Vocal detection error: {e}")
        return {"has_vocals": False, "vocal_ratio": 0.0, "harmonic_percussive_ratio": 0.0, "label": "Unknown"}
# ──────────────────────────────────────────────────────────────────────────────
# BEAT GENERATOR (Procedural Drum Synthesis)
# (Condensed version of the full beat_generator.py)
# ──────────────────────────────────────────────────────────────────────────────
BEAT_SR = 44100


class Complexity(Enum):
    MINIMAL = "minimal"
    SIMPLE = "simple"
    MEDIUM = "medium"
    COMPLEX = "complex"
    INTRICATE = "intricate"


class Energy(Enum):
    SOFT = "soft"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    INTENSE = "intense"


class Mood(Enum):
    DARK = "dark"
    MELANCHOLIC = "melancholic"
    NEUTRAL = "neutral"
    UPLIFTING = "uplifting"
    AGGRESSIVE = "aggressive"
@dataclass
class BeatParams:
    genre: str
    sub_genre: Optional[str]
    bpm: float
    bars: int
    complexity: Complexity
    energy: Energy
    mood: Mood
    time_signature: Tuple[int, int]
    swing: float
    humanize: float
    include_fills: bool
    instruments: List[str]
    description: str
# Genre defaults
_GENRE_DEFAULTS = {
    "hiphop": {"bpm": 90, "range": (70, 110)},
    "trap": {"bpm": 140, "range": (120, 160)},
    "edm": {"bpm": 128, "range": (118, 150)},
    "house": {"bpm": 124, "range": (118, 130)},
    "techno": {"bpm": 130, "range": (120, 145)},
    "dnb": {"bpm": 174, "range": (160, 180)},
    "rock": {"bpm": 120, "range": (100, 150)},
    "metal": {"bpm": 160, "range": (100, 220)},
    "pop": {"bpm": 120, "range": (95, 135)},
    "jazz": {"bpm": 120, "range": (60, 200)},
    "reggae": {"bpm": 80, "range": (65, 100)},
    "funk": {"bpm": 105, "range": (85, 125)},
    "ambient": {"bpm": 75, "range": (50, 100)},
}

_GENRE_KEYWORDS = {
    "hiphop": ["hip hop", "hip-hop", "hiphop", "rap", "boom bap", "lofi", "lo-fi"],
    "trap": ["trap", "atlanta trap", "drill", "phonk"],
    "edm": ["edm", "electronic dance", "electronic"],
    "house": ["house", "deep house", "tech house"],
    "techno": ["techno", "industrial", "acid"],
    "dnb": ["drum and bass", "dnb", "d&b", "jungle"],
    "rock": ["rock", "indie", "punk", "grunge"],
    "metal": ["metal", "heavy metal", "thrash", "death metal"],
    "pop": ["pop", "mainstream"],
    "jazz": ["jazz", "swing", "bebop"],
    "reggae": ["reggae", "dub", "ska", "dancehall"],
    "funk": ["funk", "funky", "groove", "disco", "soul"],
    "ambient": ["ambient", "chill", "chillout", "downtempo"],
}


def _kw_in(kw: str, text: str) -> bool:
    # Word-boundary match so e.g. "rap" does not match inside "trap"
    pattern = r"(?<![a-z])" + re.escape(kw) + r"(?![a-z])"
    return bool(re.search(pattern, text))
def parse_prompt(prompt: str) -> BeatParams:
    text = prompt.lower().strip()
    # Genre: first keyword match wins
    genre = "hiphop"
    for g, keywords in _GENRE_KEYWORDS.items():
        if any(_kw_in(kw, text) for kw in keywords):
            genre = g
            break
    # BPM: explicit "<n> bpm" in the prompt, else the genre default
    bpm_match = re.search(r"(\d{2,3})\s*(?:bpm|tempo)", text)
    if bpm_match:
        bpm = max(40.0, min(250.0, float(bpm_match.group(1))))
    else:
        bpm = float(_GENRE_DEFAULTS.get(genre, {}).get("bpm", 120))
    # Bars
    bars_match = re.search(r"(\d+)\s*(?:bar|bars|measure)", text)
    bars = int(bars_match.group(1)) if bars_match else 4
    bars = max(1, min(32, bars))
    # Complexity
    if any(w in text for w in ["complex", "intricate", "busy"]):
        complexity = Complexity.COMPLEX
    elif any(w in text for w in ["simple", "basic", "minimal"]):
        complexity = Complexity.SIMPLE
    else:
        complexity = Complexity.MEDIUM
    # Energy
    if any(w in text for w in ["intense", "aggressive", "hard", "heavy"]):
        energy = Energy.INTENSE
    elif any(w in text for w in ["energetic", "powerful", "driving"]):
        energy = Energy.HIGH
    elif any(w in text for w in ["soft", "gentle", "quiet"]):
        energy = Energy.SOFT
    elif any(w in text for w in ["chill", "relaxed", "laid-back"]):
        energy = Energy.LOW
    else:
        energy = Energy.MEDIUM
    # Mood
    mood = Mood.NEUTRAL
    if any(w in text for w in ["dark", "sinister"]):
        mood = Mood.DARK
    elif any(w in text for w in ["happy", "uplifting", "bright"]):
        mood = Mood.UPLIFTING
    instruments = ["kick", "snare", "hihat_c", "hihat_o", "clap"]
    desc = f"{mood.value.title()} {energy.value} {genre.title()} beat at {bpm:.0f} BPM, {bars} bars"
    return BeatParams(
        genre=genre, sub_genre=None, bpm=bpm, bars=bars,
        complexity=complexity, energy=energy, mood=mood,
        time_signature=(4, 4), swing=0.0, humanize=0.3,
        include_fills=False, instruments=instruments, description=desc,
    )
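

# Illustrative sketch (never called in production) of how a free-text prompt
# maps onto BeatParams; the prompt string is hypothetical.
def _example_parse_prompt() -> None:
    params = parse_prompt("dark trap beat at 140 bpm, 8 bars")
    assert params.genre == "trap"
    assert params.bpm == 140.0
    assert params.bars == 8
    assert params.mood is Mood.DARK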
# Drum synthesis functions
def _env(length: int, attack: float = 0.002, decay: float = 0.15):
    """Exponential decay envelope with a short linear attack ramp
    (attack in seconds; decay as a fraction of the sound's length)."""
    t = np.linspace(0, 1, length)
    env = np.exp(-t / max(decay, 1e-6))
    atk = int(attack * BEAT_SR)
    if 0 < atk < length:
        env[:atk] *= np.linspace(0, 1, atk)
    return env


def synth_kick(dur=0.4):
    """Kick: sine swept from 200 Hz down to 55 Hz, plus a short noise click."""
    n = int(dur * BEAT_SR)
    freq = np.linspace(200, 55, n)
    phase = 2 * np.pi * np.cumsum(freq) / BEAT_SR
    sine = np.sin(phase) * _env(n, 0.001, 0.25) * 0.8
    click_n = int(0.008 * BEAT_SR)
    click = np.zeros(n)
    click[:click_n] = (np.random.rand(click_n) * 2 - 1) * np.linspace(1, 0, click_n)
    return (sine + click * 0.2).astype(np.float32)


def synth_snare(dur=0.25):
    """Snare: white-noise burst plus a 200 Hz tonal body."""
    n = int(dur * BEAT_SR)
    t = np.linspace(0, dur, n)
    noise = np.random.randn(n) * _env(n, 0.001, 0.08) * 0.6
    body = np.sin(2 * np.pi * 200 * t) * _env(n, 0.001, 0.05) * 0.5
    return (noise + body).astype(np.float32)


def synth_hihat_c(dur=0.05):
    """Closed hat: differentiated (high-pass-like) noise with a very fast decay."""
    n = int(dur * BEAT_SR)
    noise = np.random.randn(n)
    filtered = np.diff(noise, prepend=noise[0])
    return (filtered * _env(n, 0.0005, 0.02) * 0.5).astype(np.float32)


def synth_hihat_o(dur=0.3):
    """Open hat: the closed-hat noise with a longer decay, plus a 6 kHz ring."""
    n = int(dur * BEAT_SR)
    noise = np.random.randn(n)
    filtered = np.diff(noise, prepend=noise[0])
    ring = np.sin(2 * np.pi * 6000 * np.linspace(0, dur, n))
    return (filtered * _env(n, 0.001, 0.2) * 0.4 + ring * _env(n, 0.001, 0.2) * 0.1).astype(np.float32)


def synth_clap(dur=0.15):
    """Clap: several noise bursts 10 ms apart, imitating many hands."""
    n = int(dur * BEAT_SR)
    noise = np.random.randn(n)
    env = np.zeros(n)
    for offset_ms in [0, 10, 20, 30]:
        offset = int(offset_ms * BEAT_SR / 1000)
        if offset < n:
            env[offset:] += _env(n - offset, 0.001, 0.04) * (1 - offset_ms / 50)
    return (noise * env * 0.5).astype(np.float32)


# Sound bank, synthesized once at import time and reused for every render
_SOUNDS = {
    "kick": synth_kick(),
    "snare": synth_snare(),
    "hihat_c": synth_hihat_c(),
    "hihat_o": synth_hihat_o(),
    "clap": synth_clap(),
}
def generate_pattern(params: BeatParams) -> Dict[str, List[int]]:
    steps = 16
    pattern = {}
    genre = params.genre
    cx = params.complexity
    # Kick
    kick = [0] * steps
    if genre in ["edm", "house", "techno"]:
        for i in range(0, steps, 4):  # four-on-the-floor
            kick[i] = 1
    elif genre == "dnb":
        kick[0] = 1
        kick[10] = 1
    elif genre in ["hiphop", "trap"]:
        kick[0] = 1
        kick[6] = 1
        if cx in [Complexity.COMPLEX, Complexity.INTRICATE]:
            kick[10] = 1
    else:
        kick[0] = 1
        kick[8] = 1
    pattern["kick"] = kick
    # Snare
    snare = [0] * steps
    if genre == "reggae":
        snare[8] = 1
    elif genre == "trap":
        snare[12] = 1
    else:
        snare[4] = 1
        snare[12] = 1
    pattern["snare"] = snare
    # Hi-hats
    hh_c = [0] * steps
    if cx == Complexity.MINIMAL:
        for i in range(0, steps, 4):
            hh_c[i] = 1
    elif cx == Complexity.SIMPLE:
        for i in range(0, steps, 2):
            hh_c[i] = 1
    else:
        for i in range(steps):
            hh_c[i] = 1
    pattern["hihat_c"] = hh_c
    hh_o = [0] * steps
    if cx != Complexity.MINIMAL:
        hh_o[7] = 1
        hh_o[15] = 1
    pattern["hihat_o"] = hh_o
    # Clap
    clap = [0] * steps
    if genre in ["trap", "edm", "house"]:
        clap[4] = 1
        clap[12] = 1
    pattern["clap"] = clap
    return pattern
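

# One bar of the "house" pattern at MEDIUM complexity on the 16-step grid
# (x = hit, . = rest), as produced by generate_pattern above:
#   kick:    x...x...x...x...   (four-on-the-floor)
#   snare:   ....x.......x...   (backbeats on steps 5 and 13)
#   hihat_c: xxxxxxxxxxxxxxxx   (straight 16th notes)
#   hihat_o: .......x.......x   (off-beat accents)
#   clap:    ....x.......x...   (layered with the snare)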
def render_beat(pattern: Dict[str, List[int]], params: BeatParams) -> np.ndarray:
    steps_per_bar = 16
    beat_dur = 60.0 / params.bpm
    step_dur = beat_dur / 4.0  # 16 steps per 4/4 bar = 16th notes
    bar_samples = int(steps_per_bar * step_dur * BEAT_SR)
    total_samples = bar_samples * params.bars
    audio = np.zeros(total_samples, dtype=np.float32)
    gain_map = {"kick": 0.95, "snare": 0.80, "hihat_c": 0.50, "hihat_o": 0.55, "clap": 0.70}
    for instrument, steps in pattern.items():
        sound = _SOUNDS.get(instrument)
        if sound is None:
            continue
        base_gain = gain_map.get(instrument, 0.6)
        for bar in range(params.bars):
            for step, hit in enumerate(steps):
                if not hit:
                    continue
                sample_pos = bar * bar_samples + int(step * step_dur * BEAT_SR)
                # Humanize: small gaussian timing jitter around the grid
                if params.humanize > 0:
                    sample_pos += int(np.random.normal(0, params.humanize * 0.01 * BEAT_SR))
                sample_pos = max(0, min(sample_pos, total_samples - 1))
                end = min(sample_pos + len(sound), total_samples)
                write_len = end - sample_pos
                if write_len > 0:
                    audio[sample_pos:end] += sound[:write_len] * base_gain
    peak = np.abs(audio).max()
    if peak > 0.95:
        audio = audio * (0.95 / peak)
    return audio
def generate_beat_full(prompt: str, output_path: str) -> dict:
    params = parse_prompt(prompt)
    pattern = generate_pattern(params)
    audio = render_beat(pattern, params)
    duration = len(audio) / BEAT_SR
    sf.write(output_path, audio, BEAT_SR, subtype="PCM_16")
    return {
        "genre": params.genre,
        "bpm": params.bpm,
        "bars": params.bars,
        "complexity": params.complexity.value,
        "description": params.description,
        "duration": round(duration, 3),
        "pattern": pattern,
        "sample_rate": BEAT_SR,
    }
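

# Usage sketch (prompt and path are hypothetical): "house" defaults to 124 BPM,
# so with no BPM in the prompt the metadata reports 124.0.
#   meta = generate_beat_full("energetic house beat, 8 bars", "/tmp/beat.wav")
#   meta["bpm"], meta["bars"]  # -> 124.0, 8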
# ──────────────────────────────────────────────────────────────────────────────
# SHAZAM RECOGNITION
# ──────────────────────────────────────────────────────────────────────────────
def convert_to_wav(input_path: str) -> Optional[str]:
    """Convert to 44.1 kHz mono 16-bit WAV with ffmpeg; returns None on failure."""
    output_path = input_path.rsplit('.', 1)[0] + '_converted.wav'
    try:
        cmd = ['ffmpeg', '-y', '-i', input_path, '-ar', '44100', '-ac', '1',
               '-sample_fmt', 's16', '-f', 'wav', output_path]
        result = subprocess.run(cmd, capture_output=True, timeout=30)
        if result.returncode == 0 and os.path.exists(output_path):
            return output_path
    except Exception as e:
        print(f"Conversion error: {e}")
    return None
async def recognize_shazam(audio_path: str) -> Optional[dict]:
    if not RAPIDAPI_KEY:
        return None
    headers = {"X-RapidAPI-Key": RAPIDAPI_KEY, "X-RapidAPI-Host": SHAZAM_HOST}
    try:
        with open(audio_path, 'rb') as f:
            audio_bytes = f.read()
        async with httpx.AsyncClient(timeout=25) as client:
            response = await client.post(
                SHAZAM_URL, headers=headers,
                files={"file": ("audio.wav", audio_bytes, "audio/wav")},
            )
        if response.status_code != 200:
            return None
        data = response.json()
        track = data.get("track") or data
        if not track.get("title") and not track.get("heading"):
            return None
        title = track.get("title", "Unknown")
        artist = track.get("subtitle", "Unknown Artist")
        album = cover = year = spotify_url = apple_music_url = None
        shazam_url = track.get("url")
        for section in track.get("sections", []):
            if section.get("type") == "SONG":
                for meta in section.get("metadata", []):
                    if meta.get("title") == "Album":
                        album = meta.get("text")
                    elif meta.get("title") == "Released":
                        year = meta.get("text")
        images = track.get("images", {})
        cover = images.get("coverarthq") or images.get("coverart")
        hub = track.get("hub", {})
        for provider in hub.get("providers", []):
            ptype = provider.get("type", "").upper()
            for action in provider.get("actions", []):
                uri = action.get("uri", "")
                if ptype == "SPOTIFY" and uri:
                    spotify_url = uri
                elif ptype == "APPLE" and uri:
                    apple_music_url = uri
        score = len(data.get("matches", []))
        return {"title": title, "artist": artist, "album": album, "cover": cover, "year": year,
                "spotify": spotify_url, "apple_music": apple_music_url, "shazam_url": shazam_url,
                "score": max(score, 1), "source": "shazam"}
    except Exception as e:
        print(f"Shazam error: {e}")
        return None
# ──────────────────────────────────────────────────────────────────────────────
# API ROUTES
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    return {"status": "ok", "service": "AutoMixAI Backend v2.0", "features": [
        "upload", "analyze", "mix (advanced DJ)", "generate", "recognize"
    ]}


@app.get("/health")
def health():
    return {"status": "healthy"}
# ── Upload ───────────────────────────────────────────────────────────────────
ALLOWED_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".aac"}


@app.post("/upload", response_model=UploadResponse)
async def upload_audio(file: UploadFile = File(...)):
    original_name = file.filename or "unknown"
    ext = Path(original_name).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(status_code=400, detail=f"Unsupported file type '{ext}'.")
    file_id = generate_file_id()
    save_path = UPLOAD_DIR / f"{file_id}{ext}"
    try:
        contents = await file.read()
        save_path.write_bytes(contents)
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Failed to save file.") from exc
    try:
        info = get_audio_info(str(save_path))
    except Exception as exc:
        save_path.unlink(missing_ok=True)
        raise HTTPException(status_code=400, detail="Not a valid audio file.") from exc
    return UploadResponse(file_id=file_id, filename=original_name, duration=round(info["duration"], 2))
# ── Analyze ──────────────────────────────────────────────────────────────────
@app.post("/analyze", response_model=AnalysisResponse)
async def analyze_audio(request: AnalyzeRequest):
    try:
        file_path = find_upload(request.file_id)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"File '{request.file_id}' not found.")
    file_path_str = str(file_path)
    # Load for librosa analysis
    y, sr = load_audio(file_path_str)
    beat_times = detect_beats(y, sr)
    bpm = estimate_bpm_from_beats(beat_times) if len(beat_times) >= 2 else estimate_bpm_librosa(y, sr)
    # ML genre classification (transformer model)
    genre_result = classify_genre(file_path_str)
    # ML mood classification (transformer model)
    mood_result = classify_mood(file_path_str)
    # Vocal detection (spectral analysis)
    vocal_result = detect_vocals(y, sr)
    # Energy (RMS-based; thresholds are heuristic)
    rms = float(np.mean(librosa.feature.rms(y=y)))
    if rms > 0.12:
        energy = "high"
    elif rms > 0.06:
        energy = "medium"
    else:
        energy = "low"
    return AnalysisResponse(
        file_id=request.file_id,
        bpm=bpm,
        beat_times=beat_times,
        duration=round(len(y) / sr, 2),
        sample_rate=sr,
        genre=genre_result["genre"],
        genre_confidence=genre_result["confidence"],
        genre_top3=genre_result.get("top3", []),
        mood=mood_result["mood"],
        has_vocals=vocal_result["has_vocals"],
        energy=energy,
    )
# ── Mix ──────────────────────────────────────────────────────────────────────
@app.post("/mix", response_model=MixResponse)
async def mix_tracks(request: MixRequest):
    try:
        path_a = find_upload(request.file_id_a)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Track A '{request.file_id_a}' not found.")
    try:
        path_b = find_upload(request.file_id_b)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"Track B '{request.file_id_b}' not found.")
    output_id = generate_file_id()
    output_path = OUTPUT_DIR / f"{output_id}.wav"
    try:
        result = create_advanced_mix(
            path_a=str(path_a),
            path_b=str(path_b),
            output_path=str(output_path),
            crossfade_duration=request.crossfade_duration,
            bass_boost=request.bass_boost,
            brightness=request.brightness,
            vocal_boost=request.vocal_boost,
            pan_a=request.pan_a,
            pan_b=request.pan_b,
            eq_transition=request.eq_transition,
        )
    except Exception as exc:
        print(f"Mix error: {exc}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Mixing failed: {exc}") from exc
    return MixResponse(
        output_file_id=output_id,
        duration=result["duration"],
        bpm_a=result["bpm_a"],
        bpm_b=result["bpm_b"],
        target_bpm=result["target_bpm"],
    )
# ── Generate (Procedural Synth) ──────────────────────────────────────────────
@app.post("/generate", response_model=GenerateBeatResponse)
async def generate_beat_route(request: GenerateBeatRequest):
    output_id = generate_file_id()
    output_path = OUTPUT_DIR / f"{output_id}.wav"
    effective_prompt = request.prompt
    if request.bars:
        effective_prompt = f"{request.prompt} {request.bars} bars"
    try:
        result = generate_beat_full(effective_prompt, str(output_path))
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Beat generation failed: {exc}") from exc
    pattern_raw = result["pattern"]
    return GenerateBeatResponse(
        output_file_id=output_id,
        genre=result["genre"],
        bpm=result["bpm"],
        bars=result["bars"],
        complexity=result["complexity"],
        description=result["description"],
        duration=result["duration"],
        pattern=PatternInfo(
            kick=pattern_raw.get("kick", [0] * 16),
            snare=pattern_raw.get("snare", [0] * 16),
            hihat_c=pattern_raw.get("hihat_c", [0] * 16),
            hihat_o=pattern_raw.get("hihat_o", [0] * 16),
            clap=pattern_raw.get("clap", [0] * 16),
        ),
    )
# ── Generate AI (MusicGen via HF Inference API) ──────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MUSICGEN_API_URL = "https://router.huggingface.co/hf-inference/models/facebook/musicgen-small"


class GenerateAIRequest(BaseModel):
    prompt: str = Field(..., min_length=3, max_length=500)
    duration: int = Field(default=10, ge=3, le=30)


class GenerateAIResponse(BaseModel):
    output_file_id: str
    prompt: str
    duration: float
    model: str = "facebook/musicgen-small"
    sample_rate: int = 32000
    message: str = "AI beat generated successfully."


# Route path is an assumption; the module docstring only lists /generate.
@app.post("/generate-ai", response_model=GenerateAIResponse)
async def generate_beat_ai(request: GenerateAIRequest):
    """Generate a beat using Meta's MusicGen via the HuggingFace Inference API.

    Note: request.duration is validated but not forwarded in the payload; the
    endpoint currently returns whatever length the hosted model produces.
    """
    output_id = generate_file_id()
    output_path = OUTPUT_DIR / f"{output_id}.wav"
    headers = {
        "Content-Type": "application/json",
        "x-wait-for-model": "true",  # wait for the model to load instead of a 503
    }
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    payload = {
        "inputs": request.prompt,
    }
    try:
        print(f"MusicGen AI generating: '{request.prompt}' ({request.duration}s)")
        async with httpx.AsyncClient(timeout=300) as client:
            response = await client.post(
                MUSICGEN_API_URL,
                headers=headers,
                json=payload,
            )
        print(f"HF API response: status={response.status_code}, "
              f"content-type={response.headers.get('content-type', 'unknown')}, "
              f"size={len(response.content)} bytes")
        if response.status_code == 503:
            error_data = response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
            wait_time = error_data.get("estimated_time", 30)
            raise HTTPException(
                status_code=503,
                detail=f"MusicGen model is loading, please try again in ~{int(wait_time)} seconds."
            )
        if response.status_code != 200:
            error_msg = response.text[:500]
            print(f"HF API error: {error_msg}")
            raise HTTPException(
                status_code=502,
                detail=f"HF Inference API error ({response.status_code}): {error_msg}"
            )
        # Response body is raw audio bytes (FLAC format)
        audio_bytes = response.content
        # Save the raw audio first
        temp_path = str(output_path).replace(".wav", "_raw.flac")
        with open(temp_path, "wb") as f:
            f.write(audio_bytes)
        # Convert to WAV via librosa/soundfile
        y, sr = librosa.load(temp_path, sr=None, mono=True)
        sf.write(str(output_path), y, sr, subtype="PCM_16")
        os.remove(temp_path)  # clean up the temp FLAC
        actual_duration = round(len(y) / sr, 2)
        print(f"MusicGen AI complete: {actual_duration}s")
    except HTTPException:
        raise
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"AI generation failed: {exc}") from exc
    return GenerateAIResponse(
        output_file_id=output_id,
        prompt=request.prompt,
        duration=actual_duration,
        sample_rate=int(sr),
    )
# ── Output ───────────────────────────────────────────────────────────────────
@app.get("/output/{file_id}")
async def download_output(file_id: str):
    output_path = OUTPUT_DIR / f"{file_id}.wav"
    if not output_path.exists():
        raise HTTPException(status_code=404, detail=f"Output file '{file_id}' not found.")
    return FileResponse(str(output_path), media_type="audio/wav", filename=f"automix_{file_id}.wav")
# ── Recognize ────────────────────────────────────────────────────────────────
@app.post("/recognize")
async def recognize(file: UploadFile = File(...)):
    if not file:
        raise HTTPException(status_code=400, detail="No file uploaded")
    suffix = os.path.splitext(file.filename or "audio.webm")[1] or ".webm"
    tmp_path = os.path.join(tempfile.gettempdir(), f"shazam_{uuid.uuid4().hex}{suffix}")
    converted_path = None
    try:
        content = await file.read()
        if len(content) < 500:
            return {"status": "error", "message": "Audio too small"}
        with open(tmp_path, 'wb') as f:
            f.write(content)
        converted_path = convert_to_wav(tmp_path)
        work_path = converted_path if converted_path else tmp_path
        result = await recognize_shazam(work_path)
        if result:
            return {
                "status": "found", "title": result["title"], "artist": result["artist"],
                "album": result.get("album"), "cover": result.get("cover"),
                "year": result.get("year"), "spotify": result.get("spotify"),
                "apple_music": result.get("apple_music"), "shazam_url": result.get("shazam_url"),
                "score": result.get("score", 0), "source": result.get("source", "unknown"),
                "match_quality": "high", "is_early": True,
            }
        else:
            return {"status": "not_found", "message": "No song matched.", "is_early": False}
    except Exception as e:
        return {"status": "error", "message": f"Recognition failed: {e}"}
    finally:
        # Best-effort cleanup of temp files
        for path in [tmp_path, converted_path]:
            if path:
                try:
                    os.unlink(path)
                except OSError:
                    pass
# ──────────────────────────────────────────────────────────────────────────────
# ENTRYPOINT
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
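
# Example client session against a locally running instance (sketch; endpoint
# paths come from the module docstring, host/port from uvicorn.run above):
#   import httpx
#   r = httpx.post("http://localhost:7860/upload",
#                  files={"file": ("song.mp3", open("song.mp3", "rb"), "audio/mpeg")})
#   file_id = r.json()["file_id"]
#   analysis = httpx.post("http://localhost:7860/analyze",
#                         json={"file_id": file_id}, timeout=300).json()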