hswift committed on
Commit
088b7f2
·
verified ·
1 Parent(s): 9f76617

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -46
app.py CHANGED
@@ -2,15 +2,40 @@ from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import io
5
- import wave
6
- import math
7
- import struct
 
 
8
 
9
- # Initialize the FastAPI app
 
 
 
 
10
  app = FastAPI()
11
 
12
- # IMPORTANT: Add CORS middleware to allow requests from your frontend
13
- # This is crucial for connecting the two parts.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"], # Allows all origins
@@ -19,55 +44,57 @@ app.add_middleware(
19
  allow_headers=["*"], # Allows all headers
20
  )
21
 
22
- def create_mock_wav_in_memory(prompt: str) -> io.BytesIO:
23
- """Generates a 1-second, 440Hz sine wave WAV file in memory."""
24
- sample_rate = 44100
25
- duration_seconds = 1
26
- frequency = 440.0 # A4 note
27
- num_samples = duration_seconds * sample_rate
28
-
29
- # Use io.BytesIO to build the WAV file in memory
30
- wav_file_in_memory = io.BytesIO()
31
-
32
- with wave.open(wav_file_in_memory, 'wb') as w:
33
- w.setnchannels(1) # Mono
34
- w.setsampwidth(2) # 16-bit PCM
35
- w.setframerate(sample_rate)
36
-
37
- for i in range(num_samples):
38
- # Calculate the sample value for the sine wave
39
- value = int(32767.0 * math.sin(2 * math.pi * frequency * i / sample_rate))
40
- # Pack the value as a 16-bit signed integer
41
- data = struct.pack('<h', value)
42
- w.writeframesraw(data)
43
-
44
- # Go back to the beginning of the in-memory file so it can be read
45
- wav_file_in_memory.seek(0)
46
- return wav_file_in_memory
47
-
48
  @app.post("/generate-audio")
49
  async def generate_audio_endpoint(payload: dict):
50
  """
51
- This endpoint receives a text prompt and returns a generated audio file.
52
  """
 
 
 
 
53
  prompt = payload.get("prompt")
54
  if not prompt:
55
  raise HTTPException(status_code=400, detail="A 'prompt' is required in the request body.")
56
 
57
- # Generate the mock audio data
58
- audio_data_stream = create_mock_wav_in_memory(prompt)
59
-
60
- # Create a safe filename from the prompt
61
- safe_filename = "".join(c for c in prompt if c.isalnum() or c in (' ', '_')).rstrip()[:50] + ".wav"
62
-
63
- # Return the in-memory WAV file as a streaming response
64
- return StreamingResponse(
65
- audio_data_stream,
66
- media_type="audio/wav",
67
- headers={"Content-Disposition": f"attachment; filename=\"{safe_filename}\""}
68
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  @app.get("/")
71
  def read_root():
72
  """A simple root endpoint to confirm the API is running."""
73
- return {"message": "Mock Audio Generation API is running."}
 
2
  from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import io
5
+ import torch
6
+ from diffusers import AudioLDM2Pipeline
7
+ from scipy.io.wavfile import write as write_wav
8
+ import numpy as np
9
+ import logging
10
 
11
+ # --- Setup Logging ---
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # --- Initialize FastAPI App ---
16
  app = FastAPI()
17
 
18
+ # --- Model Loading ---
19
+ # This section loads the AI model when the application starts.
20
+ # Loading it once at startup, rather than on every request, is crucial for performance.
21
+ MODEL_REPO = "cvssp/audioldm2"
22
+ pipeline = None
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
25
+
26
+ try:
27
+ logger.info(f"Attempting to load model '{MODEL_REPO}' on device: {device} with dtype: {torch_dtype}")
28
+ # Load the pre-trained AudioLDM2 pipeline
29
+ pipeline = AudioLDM2Pipeline.from_pretrained(MODEL_REPO, torch_dtype=torch_dtype)
30
+ pipeline = pipeline.to(device)
31
+ logger.info("Model loaded successfully and moved to device.")
32
+ except Exception as e:
33
+ logger.error(f"Fatal error during model loading: {e}", exc_info=True)
34
+ # If the model fails to load, the 'pipeline' variable will remain None.
35
+ # The endpoint will then report an error.
36
+
37
+ # --- CORS Middleware ---
38
+ # Allows the frontend website to communicate with this API
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"], # Allows all origins
 
44
  allow_headers=["*"], # Allows all headers
45
  )
46
 
47
+ # --- API Endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @app.post("/generate-audio")
49
  async def generate_audio_endpoint(payload: dict):
50
  """
51
+ Receives a text prompt and returns a generated WAV audio file.
52
  """
53
+ if pipeline is None:
54
+ logger.error("Request received, but model is not loaded.")
55
+ raise HTTPException(status_code=503, detail="Model is not available or failed to load. Please check the server logs.")
56
+
57
  prompt = payload.get("prompt")
58
  if not prompt:
59
  raise HTTPException(status_code=400, detail="A 'prompt' is required in the request body.")
60
 
61
+ try:
62
+ logger.info(f"Generating audio for prompt: '{prompt}'")
63
+
64
+ # Generate audio. The model works well with negative prompts to guide it.
65
+ audio = pipeline(
66
+ prompt,
67
+ negative_prompt="Low quality, noisy, muffled, mono", # Helps improve quality
68
+ num_inference_steps=200, # Higher steps can improve quality
69
+ audio_length_in_s=2.5 # Generate 2.5-second clips
70
+ ).audios[0]
71
+
72
+ # The model output is a numpy array with float values from -1.0 to 1.0.
73
+ # We need to convert it to a 16-bit PCM WAV file.
74
+ sample_rate = 16000 # The model's default sample rate
75
+
76
+ # Scale to 16-bit integer range
77
+ audio_int16 = np.int16(audio * 32767)
78
+
79
+ # Use io.BytesIO to build the WAV file in memory
80
+ wav_file_in_memory = io.BytesIO()
81
+ write_wav(wav_file_in_memory, sample_rate, audio_int16)
82
+ wav_file_in_memory.seek(0) # Rewind to the beginning of the stream
83
+
84
+ safe_filename = "".join(c for c in prompt if c.isalnum() or c in (' ', '_')).rstrip()[:50] + ".wav"
85
+
86
+ logger.info(f"Successfully generated audio for prompt: '{prompt}'")
87
+ return StreamingResponse(
88
+ wav_file_in_memory,
89
+ media_type="audio/wav",
90
+ headers={"Content-Disposition": f"attachment; filename=\"{safe_filename}\""}
91
+ )
92
+
93
+ except Exception as e:
94
+ logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
95
+ raise HTTPException(status_code=500, detail=f"An unexpected error occurred during audio generation.")
96
 
97
  @app.get("/")
98
  def read_root():
99
  """A simple root endpoint to confirm the API is running."""
100
+ return {"message": "Decent Sampler Audio Generation API is running."}