# NOTE: the lines that originally appeared here were web-page scaffolding from
# a Hugging Face Space listing (status text, file size, commit hashes, and a
# line-number gutter) captured during extraction — not Python source. They have
# been replaced by this comment so the module parses.
import os
import logging
# Used to delete the temporary WAV file only after the response has finished
# streaming back to the client (see /generate-audio below).
from starlette.background import BackgroundTask
# --- Cache-directory setup: MUST run before importing torch or diffusers ---
# huggingface_hub, torch, and friends read these environment variables at
# import time; pointing them at a writable /tmp path avoids permission errors
# in containers whose default cache locations are read-only.
CACHE_DIR = "/tmp/huggingface_cache"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_HUB_CACHE'] = os.path.join(CACHE_DIR, 'hub')
os.environ['TORCH_HOME'] = os.path.join(CACHE_DIR, 'torch')
os.makedirs(os.path.join(CACHE_DIR, 'hub'), exist_ok=True)
os.makedirs(os.path.join(CACHE_DIR, 'torch'), exist_ok=True)
# Safe to import the heavy libraries now that the cache env vars are in place.
import torch
import tempfile
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from diffusers import AudioLDMPipeline
from scipy.io.wavfile import write as write_wav
import numpy as np
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

# --- App Setup ---
app = FastAPI()

# Wide-open CORS policy: any origin, method, or header may call this API.
# Convenient during development; tighten allow_origins before production use.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)
# --- Pydantic Model for Request Body ---
class AudioRequest(BaseModel):
    """JSON body for POST /generate-audio: a single free-text prompt."""
    prompt: str
# --- Model Loading ---
# Prefer GPU with half precision when CUDA is available; otherwise fall back
# to CPU with full float32.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
logger.info(f"Using device: {device} with dtype: {torch_dtype}")
logger.info(f"Using model cache directory: {CACHE_DIR}")

# `pipe` stays None if loading fails; the endpoint answers 503 in that case.
pipe = None
try:
    # Stable, recommended AudioLDM checkpoint.
    repo_id = "cvssp/audioldm-s-full-v2"
    pipe = AudioLDMPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        # The env vars set at the top of the file already redirect the cache;
        # passing cache_dir explicitly is belt-and-braces.
        cache_dir=CACHE_DIR,
    ).to(device)
    logger.info(f"Successfully loaded model: {repo_id}")
except Exception as e:
    logger.error(f"Failed to load the model: {e}", exc_info=True)
    pipe = None
# --- API Endpoint ---
@app.post("/generate-audio")
async def generate_audio_endpoint(request: AudioRequest):
    """Generate a short WAV clip from a text prompt and stream it back.

    Writes the generated audio to a temp file and returns it as a
    FileResponse; a BackgroundTask removes the file only after the response
    has been sent. Raises HTTP 503 if the model failed to load at startup
    and HTTP 500 on any generation error.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Check server logs for loading errors.")
    prompt = request.prompt
    logger.info(f"Generating audio for prompt: '{prompt}'")
    temp_file_path = ""
    try:
        audio = pipe(
            prompt,
            num_inference_steps=200,
            audio_length_in_s=5.0,
            guidance_scale=7.0
        ).audios[0]
        # BUG FIX: AudioLDM's HiFi-GAN vocoder synthesizes 16 kHz audio, not
        # 44.1 kHz. Writing the WAV header with 44100 made playback ~2.75x too
        # fast and high-pitched. Read the true rate from the pipeline's
        # vocoder config, falling back to AudioLDM's documented 16 kHz.
        try:
            sample_rate = int(pipe.vocoder.config.sampling_rate)
        except AttributeError:
            sample_rate = 16000
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name
            # Clip to [-1, 1] before scaling: samples slightly outside that
            # range would otherwise wrap around when cast to int16 (pops).
            audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
            write_wav(temp_file_path, sample_rate, audio_int16)
        logger.info(f"Audio saved to temporary file: {temp_file_path}")
        # Delete the temp file only AFTER the response body has been sent.
        cleanup_task = BackgroundTask(os.remove, temp_file_path)
        return FileResponse(
            path=temp_file_path,
            media_type='audio/wav',
            filename=f"{prompt[:50].replace(' ', '_')}.wav",
            background=cleanup_task
        )
    except Exception as e:
        logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)  # best-effort cleanup on failure
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
def read_root():
    """Health-check endpoint confirming the service is up."""
    status_payload = {"status": "Audio generation API is running."}
    return status_payload