hswift committed
Commit 1d43544 · verified · Parent: 6585c4c

Update app.py

Files changed (1)
app.py +25 -30
app.py CHANGED

@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import io
 import torch
 from diffusers import AudioLDM2Pipeline
+from transformers import GPT2LMHeadModel  # <-- IMPORT THE CORRECT MODEL TYPE
 from scipy.io.wavfile import write as write_wav
 import numpy as np
 
@@ -14,12 +15,9 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # --- CRITICAL FIX for Hugging Face Spaces Permissions ---
-# Set the cache directory for all Hugging Face libraries BEFORE they are used.
-# This forces the model download and any temporary files to a writable location.
 cache_dir = "/tmp/huggingface_cache"
 os.environ["HF_HOME"] = cache_dir
 os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
-# Create the directory if it doesn't exist, just in case.
 os.makedirs(cache_dir, exist_ok=True)
 logger.info(f"Hugging Face cache directory globally set to: {cache_dir}")
 
@@ -27,7 +25,6 @@ logger.info(f"Hugging Face cache directory globally set to: {cache_dir}")
 app = FastAPI()
 
 # --- Model Loading ---
-# This section loads the AI model when the application starts.
 MODEL_REPO = "cvssp/audioldm2"
 pipeline = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -35,39 +32,45 @@ torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
 try:
     logger.info(f"Attempting to load model '{MODEL_REPO}' on device: {device} with dtype: {torch_dtype}")
-    # The cache_dir argument is now redundant because of the environment variable,
-    # but we'll leave it for extra safety.
+
+    # --- FIX for Model Component Mismatch ---
+    # 1. Manually load the correct GPT2 model variant (GPT2LMHeadModel).
+    #    The sub-model 'gpt2' is used by audioldm2 for prompt understanding.
+    logger.info("Pre-loading the correct language model component (GPT2LMHeadModel)...")
+    language_model = GPT2LMHeadModel.from_pretrained(
+        "openai-community/gpt2", cache_dir=cache_dir
+    ).to(device)
+    logger.info("Language model component loaded successfully.")
+
+    # 2. Load the main pipeline, injecting our pre-loaded component.
+    #    This forces the pipeline to use the correct model and avoid the AttributeError.
     pipeline = AudioLDM2Pipeline.from_pretrained(
-        MODEL_REPO,
+        MODEL_REPO,
         torch_dtype=torch_dtype,
-        cache_dir=cache_dir
+        cache_dir=cache_dir,
+        language_model=language_model,  # <-- INJECT THE CORRECT COMPONENT
     )
     pipeline = pipeline.to(device)
     logger.info("Model loaded successfully and moved to device.")
 except Exception as e:
     logger.error(f"Fatal error during model loading: {e}", exc_info=True)
     # If the model fails to load, the 'pipeline' variable will remain None.
-    # The endpoint will then report an error.
 
 # --- CORS Middleware ---
-# Allows the frontend website to communicate with this API
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Allows all origins
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"],  # Allows all methods
-    allow_headers=["*"],  # Allows all headers
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
 # --- API Endpoints ---
 @app.post("/generate-audio")
 async def generate_audio_endpoint(payload: dict):
-    """
-    Receives a text prompt and returns a generated WAV audio file.
-    """
     if pipeline is None:
         logger.error("Request received, but model is not loaded.")
-        raise HTTPException(status_code=503, detail="Model is not available or failed to load. Please check the server logs.")
+        raise HTTPException(status_code=503, detail="Model is not available or failed to load. Please check server logs.")
 
     prompt = payload.get("prompt")
     if not prompt:
@@ -76,25 +79,18 @@ async def generate_audio_endpoint(payload: dict):
     try:
         logger.info(f"Generating audio for prompt: '{prompt}'")
 
-        # Generate audio. The model works well with negative prompts to guide it.
         audio = pipeline(
             prompt,
-            negative_prompt="Low quality, noisy, muffled, mono",  # Helps improve quality
-            num_inference_steps=200,  # Higher steps can improve quality
-            audio_length_in_s=2.5  # Generate 2.5-second clips
+            negative_prompt="Low quality, noisy, muffled, mono",
+            num_inference_steps=200,
+            audio_length_in_s=2.5
        ).audios[0]
 
-        # The model output is a numpy array with float values from -1.0 to 1.0.
-        # We need to convert it to a 16-bit PCM WAV file.
-        sample_rate = 16000  # The model's default sample rate
-
-        # Scale to 16-bit integer range
+        sample_rate = 16000
         audio_int16 = np.int16(audio * 32767)
-
-        # Use io.BytesIO to build the WAV file in memory
        wav_file_in_memory = io.BytesIO()
         write_wav(wav_file_in_memory, sample_rate, audio_int16)
-        wav_file_in_memory.seek(0)  # Rewind to the beginning of the stream
+        wav_file_in_memory.seek(0)
 
         safe_filename = "".join(c for c in prompt if c.isalnum() or c in (' ', '_')).rstrip()[:50] + ".wav"
 
@@ -111,5 +107,4 @@ async def generate_audio_endpoint(payload: dict):
 
 @app.get("/")
 def read_root():
-    """A simple root endpoint to confirm the API is running."""
     return {"message": "Decent Sampler Audio Generation API is running."}
 
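A second caveat, relevant only if the Space has a GPU: the pipeline is loaded with torch_dtype=torch.float16 on CUDA, but the injected GPT2LMHeadModel is loaded in transformers' default float32, and pipeline.to(device) changes device, not dtype. If a dtype mismatch surfaces at inference time, a sketch of a fix (using transformers' standard torch_dtype keyword and the commit's own variables) is to load the component in the pipeline's dtype:

    # Sketch: load the injected component in the same dtype as the pipeline.
    language_model = GPT2LMHeadModel.from_pretrained(
        "openai-community/gpt2",
        cache_dir=cache_dir,
        torch_dtype=torch_dtype,  # float16 on CUDA, float32 on CPU
    ).to(device)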
 
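The endpoint's post-processing is compact enough to miss: the pipeline returns a float numpy array in roughly [-1.0, 1.0] at the 16 kHz rate hard-coded as sample_rate, which is scaled to 16-bit PCM and written to an in-memory WAV. A worked sketch of the same conversion; the np.clip guard is an addition, not in the commit, protecting against samples that stray just past ±1.0:

    import io
    import numpy as np
    from scipy.io.wavfile import write as write_wav

    audio = np.array([0.0, 0.5, -0.5, 1.0], dtype=np.float32)  # stand-in for pipeline output
    pcm = np.int16(np.clip(audio, -1.0, 1.0) * 32767)          # 0.5 -> 16383, 1.0 -> 32767

    buf = io.BytesIO()
    write_wav(buf, 16000, pcm)  # 16 kHz, matching sample_rate in app.py
    buf.seek(0)                 # rewind so the buffer can be streamed as a response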