Spaces:
Sleeping
Sleeping
"""
AutoMixAI Beat Generator — HuggingFace Space

AI-powered music/beat generation using Meta's MusicGen model.
Generates studio-quality beats, loops, and music from text prompts.

Endpoints:
    POST /generate      Generate beat/music from text prompt
    GET  /output/{id}   Download generated audio
    GET  /health        Health check
"""
import os
import re
import tempfile
import threading
import time
import uuid
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Generated WAV files are stored under the OS temp directory; the directory is
# created eagerly at import time.
OUTPUT_DIR = Path(tempfile.gettempdir()).joinpath("automixai_beats")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# MusicGen checkpoint to load; override via the MUSICGEN_MODEL env var
# (small favours speed, medium favours quality).
MODEL_ID = os.environ.get("MUSICGEN_MODEL", "facebook/musicgen-small")

# Sample rate (Hz) reported for generated audio; MusicGen outputs at 32 kHz.
SAMPLE_RATE = 32_000
# -----------------------------------------------------------------------------
# FASTAPI APP
# -----------------------------------------------------------------------------
app = FastAPI(
    title="AutoMixAI Beat Generator",
    description="AI-powered beat/music generation using Meta's MusicGen.",
    version="1.0.0",
)

# Fully open CORS: any origin, method, and header may call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# generally discouraged; if this API never relies on cookies or auth headers,
# allow_credentials=False is the safer setting — confirm before changing.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# -----------------------------------------------------------------------------
# SCHEMAS
# -----------------------------------------------------------------------------
# Request body accepted by generate_beat: the text prompt plus the sampling
# knobs forwarded to generate_music.
class GenerateRequest(BaseModel):
    # Free-text description of the desired beat/music (3-500 characters).
    prompt: str = Field(
        ...,
        max_length=500,
        min_length=3,
        description="Text prompt describing the beat/music to generate",
    )
    # Requested clip length in whole seconds, bounded to [3, 30].
    duration: int = Field(
        le=30,
        ge=3,
        default=10,
        description="Duration in seconds (3-30)",
    )
    # Sampling temperature, bounded to [0.5, 1.5].
    temperature: float = Field(
        le=1.5,
        ge=0.5,
        default=1.0,
        description="Generation temperature: lower=more predictable, higher=more creative",
    )
    # Guidance scale, bounded to [1.0, 10.0].
    guidance_scale: float = Field(
        le=10.0,
        ge=1.0,
        default=3.0,
        description="How closely to follow the prompt (higher=stricter)",
    )
# Response returned by generate_beat after a successful generation.
class GenerateResponse(BaseModel):
    output_file_id: str  # id to pass to the download endpoint
    prompt: str  # echo of the request prompt
    duration: float  # actual clip length in seconds
    model: str  # MusicGen checkpoint id that produced the clip
    sample_rate: int  # sample rate of the WAV, in Hz
    message: str = Field(default="Beat generated successfully.")
# -----------------------------------------------------------------------------
# MUSICGEN MODEL (lazy-loaded)
# -----------------------------------------------------------------------------
# Process-wide singletons, populated on the first call to _load_model().
_model = None
_processor = None
# Guards the one-time load so concurrent callers (e.g. handlers running in a
# threadpool) don't each load the checkpoint.
_model_lock = threading.Lock()


def _load_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The model is moved to CUDA and cast to FP16 when CUDA is available,
    otherwise it stays on CPU. The globals are assigned only after the model
    is fully placed, so no caller ever sees a half-initialized model.

    Returns:
        Tuple of ``(MusicgenForConditionalGeneration, AutoProcessor)``.
    """
    global _model, _processor
    # Fast path: already loaded, no locking needed.
    if _model is not None:
        return _model, _processor
    with _model_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _model is None:
            print(f"Loading MusicGen model: {MODEL_ID}")
            start = time.time()
            processor = AutoProcessor.from_pretrained(MODEL_ID)
            model = MusicgenForConditionalGeneration.from_pretrained(MODEL_ID)
            # Use GPU if available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            if device == "cuda":
                model = model.half()  # FP16 for faster GPU inference
            # Publish the processor first: _model doubles as the "loaded" flag
            # checked on the fast path above.
            _processor = processor
            _model = model
            elapsed = time.time() - start
            print(f"MusicGen loaded on {device} in {elapsed:.1f}s")
    return _model, _processor
def generate_music(prompt: str, duration: int = 10, temperature: float = 1.0,
                   guidance_scale: float = 3.0) -> tuple:
    """Generate a music/beat clip from a text prompt using MusicGen.

    Args:
        prompt: Text description of the music to generate.
        duration: Target length in seconds; converted to a token budget of
            50 tokens per second.
        temperature: Sampling temperature passed to ``model.generate``.
        guidance_scale: Guidance scale passed to ``model.generate``.

    Returns:
        ``(audio, sample_rate)``: ``audio`` is a float32 numpy array holding
        the first batch item's first channel, peak-normalized to 0.95 when it
        contains a finite non-zero peak; ``sample_rate`` is ``SAMPLE_RATE``.
    """
    model, processor = _load_model()
    device = next(model.parameters()).device

    # Tokenize the prompt and move the tensors to the model's device.
    inputs = processor(
        text=[prompt],
        padding=True,
        return_tensors="pt",
    ).to(device)

    # MusicGen generates ~50 audio-codec frames per second of output.
    tokens_per_second = 50
    max_new_tokens = int(duration * tokens_per_second)

    print(f"Generating: '{prompt}' ({duration}s, temp={temperature}, guidance={guidance_scale})")
    start = time.time()
    with torch.no_grad():
        audio_values = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            guidance_scale=guidance_scale,
            do_sample=True,
        )
    elapsed = time.time() - start
    print(f"Generation complete in {elapsed:.1f}s")

    # Index [0, 0] = first batch item, first channel (assumes generate returns
    # (batch, channels, samples), per the transformers MusicGen docs).
    # Cast to float32: on CUDA the model runs in FP16, and soundfile cannot
    # write float16 arrays.
    audio = audio_values[0, 0].float().cpu().numpy()

    # Normalize to prevent clipping; skip empty or non-finite output rather
    # than crashing in np.max or spreading NaNs through the clip.
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    if peak > 0 and np.isfinite(peak):
        audio = audio / peak * 0.95
    return audio, SAMPLE_RATE
# -----------------------------------------------------------------------------
# API ROUTES
# -----------------------------------------------------------------------------
@app.get("/")
def root():
    """Return a service banner: status, service name, model id, and features.

    Registered on ``GET /`` — without the decorator FastAPI never exposes it.
    """
    return {
        "status": "ok",
        "service": "AutoMixAI Beat Generator v1.0",
        "model": MODEL_ID,
        "features": ["text-to-music", "text-to-beat"],
    }
@app.get("/health")
def health():
    """Liveness probe for ``GET /health``; reports the configured model id.

    Does not load the model, so it stays cheap even before the first
    generation.
    """
    return {"status": "healthy", "model": MODEL_ID}
# One generation at a time: MusicGen inference is memory-heavy, so concurrent
# requests queue on this lock instead of running side by side.
_generation_lock = threading.Lock()


def _generate_to_wav(request: GenerateRequest, output_path: Path) -> tuple:
    """Run MusicGen for *request* and write a 16-bit PCM WAV to *output_path*.

    Blocking; meant to run in a worker thread. Returns ``(num_samples, sr)``.
    """
    with _generation_lock:
        audio, sr = generate_music(
            prompt=request.prompt,
            duration=request.duration,
            temperature=request.temperature,
            guidance_scale=request.guidance_scale,
        )
    sf.write(str(output_path), audio, sr, subtype="PCM_16")
    return len(audio), sr


@app.post("/generate", response_model=GenerateResponse)
async def generate_beat(request: GenerateRequest):
    """Generate a beat/music clip from a text prompt using MusicGen.

    The WAV is saved as ``OUTPUT_DIR/<output_file_id>.wav``.

    Raises:
        HTTPException: 500 if generation or writing the WAV fails.
    """
    output_id = uuid.uuid4().hex
    output_path = OUTPUT_DIR / f"{output_id}.wav"
    try:
        # Inference and the file write block for seconds to minutes; run them
        # off the event loop so other requests (e.g. /health) stay responsive.
        num_samples, sr = await run_in_threadpool(_generate_to_wav, request, output_path)
    except Exception as exc:
        import traceback
        traceback.print_exc()
        # Don't leave a truncated WAV behind for the download endpoint to serve.
        if output_path.exists():
            output_path.unlink()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(exc)}") from exc
    return GenerateResponse(
        output_file_id=output_id,
        prompt=request.prompt,
        duration=round(num_samples / sr, 2),
        model=MODEL_ID,
        sample_rate=sr,
    )
# generate_beat names outputs with uuid.uuid4().hex: 32 lowercase hex chars.
_OUTPUT_ID_RE = re.compile(r"[0-9a-f]{32}")


@app.get("/output/{file_id}")
async def download_output(file_id: str):
    """Download a generated audio file by its ``output_file_id``.

    Raises:
        HTTPException: 404 if *file_id* is not a well-formed id or no such
            WAV exists.
    """
    # file_id comes straight from the URL; only accept ids we could have
    # issued, so the value can never steer the lookup outside OUTPUT_DIR.
    if not _OUTPUT_ID_RE.fullmatch(file_id):
        raise HTTPException(status_code=404, detail=f"Output '{file_id}' not found.")
    output_path = OUTPUT_DIR / f"{file_id}.wav"
    if not output_path.is_file():
        raise HTTPException(status_code=404, detail=f"Output '{file_id}' not found.")
    return FileResponse(str(output_path), media_type="audio/wav",
                        filename=f"automix_beat_{file_id}.wav")
# -----------------------------------------------------------------------------
# ENTRYPOINT
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Direct launch: serve the app on all interfaces, port 7860.
    from uvicorn import run as run_server

    run_server(app, port=7860, host="0.0.0.0")