File size: 3,990 Bytes
6585c4c
375db4a
d59c282
375db4a
 
 
 
 
 
 
 
 
 
 
 
 
 
57f2351
 
d178bba
57f2351
d178bba
57f2351
 
088b7f2
 
d178bba
57f2351
088b7f2
57f2351
6585c4c
57f2351
d178bba
 
57f2351
d178bba
 
1d43544
d178bba
1d43544
 
d178bba
 
57f2351
 
 
 
 
 
4c26d1d
088b7f2
57f2351
4c26d1d
 
375db4a
57f2351
 
 
4c26d1d
 
 
375db4a
 
4c26d1d
57f2351
 
 
375db4a
57f2351
 
 
 
 
 
 
 
 
 
 
 
088b7f2
57f2351
 
 
 
 
088b7f2
 
679b885
57f2351
 
 
 
 
 
d59c282
 
 
 
088b7f2
57f2351
 
 
 
d59c282
088b7f2
 
 
 
57f2351
d59c282
57f2351
d178bba
 
 
57f2351
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import os
import threading

from starlette.background import BackgroundTask  # deletes temp files after the response is sent
from starlette.concurrency import run_in_threadpool

# --- Writable cache locations ---
# This must run BEFORE torch / diffusers are imported further down, so that
# huggingface_hub, torch, etc. pick up these locations instead of their
# defaults. Everything lives under /tmp, which avoids permission errors in
# environments where the default cache directories are not writable.
CACHE_DIR = "/tmp/huggingface_cache"
_HUB_CACHE_DIR = os.path.join(CACHE_DIR, 'hub')
_TORCH_CACHE_DIR = os.path.join(CACHE_DIR, 'torch')

os.environ.update(
    HF_HOME=CACHE_DIR,
    HF_HUB_CACHE=_HUB_CACHE_DIR,
    TORCH_HOME=_TORCH_CACHE_DIR,
)
for _cache_subdir in (_HUB_CACHE_DIR, _TORCH_CACHE_DIR):
    os.makedirs(_cache_subdir, exist_ok=True)


# Now it's safe to import the other libraries
import torch
import tempfile
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from diffusers import AudioLDMPipeline
from scipy.io.wavfile import write as write_wav
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

# --- App Setup ---
app = FastAPI()

# Fully open CORS policy (every origin, method and header), intended for
# development use.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination — confirm it is intended before exposing this
# service publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# --- Request schema ---
class AudioRequest(BaseModel):
    """JSON body accepted by ``POST /generate-audio``."""

    # Text description of the sound to synthesize; passed to the pipeline as-is.
    prompt: str

# --- Model Loading ---
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on the GPU; the CPU path stays in float32.
torch_dtype = torch.float32 if device == "cpu" else torch.float16

logger.info(f"Using device: {device} with dtype: {torch_dtype}")
logger.info(f"Using model cache directory: {CACHE_DIR}")

# AudioLDM checkpoint to load.
repo_id = "cvssp/audioldm-s-full-v2"

# Loading is best-effort: on any failure the server still starts with
# ``pipe`` left as None, and the generation endpoint answers 503.
pipe = None
try:
    pipe = AudioLDMPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        # Redundant with the HF_* environment variables, but explicit.
        cache_dir=CACHE_DIR,
    ).to(device)
    logger.info(f"Successfully loaded model: {repo_id}")
except Exception as e:
    logger.error(f"Failed to load the model: {e}", exc_info=True)
    pipe = None

# --- API Endpoint ---

# Generation settings applied to every request.
NUM_INFERENCE_STEPS = 200
AUDIO_LENGTH_S = 5.0
GUIDANCE_SCALE = 7.0

# Fallback WAV sample rate. AudioLDM's vocoder produces 16 kHz audio; writing
# it with a different rate (previously a hard-coded 44100) makes it play back
# sped up and pitch-shifted. The vocoder config's own value wins when present.
DEFAULT_SAMPLE_RATE = 16000

# ``pipe`` is shared module state. Generation now runs in worker threads (to
# keep the event loop free), so concurrent requests must not drive the same
# pipeline / scheduler simultaneously; this lock serializes them.
_PIPE_LOCK = threading.Lock()


def _download_filename(prompt, max_len=50):
    """Return a ``.wav`` download name built from the first *max_len* chars of *prompt*.

    Anything other than letters, digits, ``-`` and ``_`` becomes ``_`` (so the
    user-supplied prompt cannot smuggle path separators or quotes into the
    Content-Disposition header). An empty result falls back to ``audio``.
    """
    stem = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in prompt[:max_len]
    )
    return f"{stem or 'audio'}.wav"


def _generate_wav_file(prompt):
    """Synthesize *prompt* with the loaded pipeline and save it to a temp WAV file.

    Blocking (runs the full diffusion loop) — call it from a worker thread,
    not from the event loop.

    Returns:
        Path of the newly created ``.wav`` file. The caller owns the file and
        must delete it.

    Raises:
        Whatever the pipeline or the WAV writer raises; no temp file is left
        behind in that case.
    """
    with _PIPE_LOCK:
        audio = pipe(
            prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            audio_length_in_s=AUDIO_LENGTH_S,
            guidance_scale=GUIDANCE_SCALE,
        ).audios[0]

    vocoder_config = getattr(getattr(pipe, "vocoder", None), "config", None)
    sample_rate = int(getattr(vocoder_config, "sampling_rate", DEFAULT_SAMPLE_RATE))

    # Clip to [-1, 1] before scaling: out-of-range samples would otherwise
    # wrap around in the int16 cast and turn into loud clicks.
    audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    # Create and close the temp file first, then write by path: reopening a
    # still-open NamedTemporaryFile by name is not portable (fails on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file_path = temp_file.name
    try:
        write_wav(temp_file_path, sample_rate, audio_int16)
    except Exception:
        os.remove(temp_file_path)
        raise
    return temp_file_path


@app.post("/generate-audio")
async def generate_audio_endpoint(request: AudioRequest):
    """Generate a short sound clip for ``request.prompt`` and return it as WAV.

    Responds 503 when the model failed to load at startup and 500 when
    generation fails. The temporary WAV file is removed by a background task
    after the response has been sent.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Check server logs for loading errors.")

    prompt = request.prompt
    logger.info(f"Generating audio for prompt: '{prompt}'")

    try:
        # Run the long, blocking diffusion loop off the event loop so other
        # requests (including the "/" health check) are still served.
        temp_file_path = await run_in_threadpool(_generate_wav_file, prompt)
    except Exception as e:
        logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info(f"Audio saved to temporary file: {temp_file_path}")
    return FileResponse(
        path=temp_file_path,
        media_type='audio/wav',
        filename=_download_filename(prompt),
        # Delete the file only AFTER the response body has been streamed.
        background=BackgroundTask(os.remove, temp_file_path),
    )

@app.get("/")
def read_root():
    """Liveness check: report that the service process is up.

    Does not check whether the model loaded; the generation endpoint reports
    that separately.
    """
    status_message = "Audio generation API is running."
    return {"status": status_message}