hswift commited on
Commit
375db4a
·
verified ·
1 Parent(s): 4c26d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -1,4 +1,19 @@
1
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import tempfile
4
  from fastapi import FastAPI, HTTPException
@@ -8,7 +23,6 @@ from pydantic import BaseModel
8
  from diffusers import AudioLDMPipeline
9
  from scipy.io.wavfile import write as write_wav
10
  import numpy as np
11
- import logging
12
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
@@ -35,26 +49,22 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
35
  torch_dtype = torch.float16 if "cuda" in device else torch.float32
36
 
37
  logger.info(f"Using device: {device} with dtype: {torch_dtype}")
38
-
39
- # --- FIX FOR PERMISSION ERROR ---
40
- # The environment we're running in doesn't allow writing to the default '/.cache' directory.
41
- # We explicitly define a writable directory within '/tmp' for the model cache.
42
- CACHE_DIR = "/tmp/huggingface_cache"
43
- os.makedirs(CACHE_DIR, exist_ok=True)
44
  logger.info(f"Using model cache directory: {CACHE_DIR}")
45
 
 
46
  try:
47
  # Use the stable, recommended model
48
  repo_id = "cvssp/audioldm-s-full-v2"
49
  pipe = AudioLDMPipeline.from_pretrained(
50
  repo_id,
51
  torch_dtype=torch_dtype,
52
- cache_dir=CACHE_DIR # Pass the writable cache directory to the loader
 
53
  )
54
  pipe = pipe.to(device)
55
  logger.info(f"Successfully loaded model: {repo_id}")
56
  except Exception as e:
57
- logger.error(f"Failed to load the model: {e}")
58
  pipe = None # Ensure pipe is None if loading fails
59
 
60
  # --- API Endpoint ---
@@ -66,10 +76,8 @@ async def generate_audio_endpoint(request: AudioRequest):
66
  prompt = request.prompt
67
  logger.info(f"Generating audio for prompt: '{prompt}'")
68
 
69
- # Use a temporary file to store the generated audio
70
  temp_file_path = ""
71
  try:
72
- # Generate the audio waveform
73
  audio = pipe(
74
  prompt,
75
  num_inference_steps=200,
@@ -79,28 +87,21 @@ async def generate_audio_endpoint(request: AudioRequest):
79
 
80
  sample_rate = 16000
81
 
82
- # Create a temporary file to save the audio
83
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
84
  temp_file_path = temp_file.name
85
-
86
- # Normalize and convert to 16-bit integer format for WAV
87
  audio_int16 = (audio * 32767).astype(np.int16)
88
-
89
- # Write the WAV file
90
  write_wav(temp_file_path, sample_rate, audio_int16)
91
  logger.info(f"Audio saved to temporary file: {temp_file_path}")
92
 
93
- # Return the audio file as a response.
94
  return FileResponse(
95
  path=temp_file_path,
96
  media_type='audio/wav',
97
  filename=f"{prompt[:50].replace(' ', '_')}.wav",
98
- background=os.remove(temp_file_path) # Clean up the file after sending
99
  )
100
 
101
  except Exception as e:
102
  logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
103
- # Clean up the temp file if it was created before the error
104
  if temp_file_path and os.path.exists(temp_file_path):
105
  os.remove(temp_file_path)
106
  raise HTTPException(status_code=500, detail=str(e))
 
import os
import logging

# --- FIX FOR ALL PERMISSION ERRORS ---
# Set environment variables BEFORE importing torch or diffusers.
# This forces all underlying libraries (huggingface_hub, torch, etc.)
# to use a writable directory inside /tmp, avoiding any permission errors.
CACHE_DIR = "/tmp/huggingface_cache"
_HUB_CACHE = os.path.join(CACHE_DIR, "hub")
_TORCH_CACHE = os.path.join(CACHE_DIR, "torch")

os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = _HUB_CACHE
os.environ["TORCH_HOME"] = _TORCH_CACHE

# Make sure both cache locations exist before any library touches them.
for _cache_path in (_HUB_CACHE, _TORCH_CACHE):
    os.makedirs(_cache_path, exist_ok=True)


# Now it's safe to import the other libraries
import torch
import tempfile
from fastapi import FastAPI, HTTPException
 
# Third-party dependencies (safe to import now that the cache
# environment variables are already in place).
from diffusers import AudioLDMPipeline
import numpy as np
from scipy.io.wavfile import write as write_wav


# Configure logging
logging.basicConfig(level=logging.INFO)
 
# Half precision on CUDA saves memory; CPU kernels require float32.
torch_dtype = torch.float16 if "cuda" in device else torch.float32

logger.info(f"Using device: {device} with dtype: {torch_dtype}")
logger.info(f"Using model cache directory: {CACHE_DIR}")
53
 
54
+ pipe = None
55
  try:
56
  # Use the stable, recommended model
57
  repo_id = "cvssp/audioldm-s-full-v2"
58
  pipe = AudioLDMPipeline.from_pretrained(
59
  repo_id,
60
  torch_dtype=torch_dtype,
61
+ # cache_dir is still good practice but the environment variables are the real fix
62
+ cache_dir=CACHE_DIR
63
  )
64
  pipe = pipe.to(device)
65
  logger.info(f"Successfully loaded model: {repo_id}")
66
  except Exception as e:
67
+ logger.error(f"Failed to load the model: {e}", exc_info=True)
68
  pipe = None # Ensure pipe is None if loading fails
69
 
70
  # --- API Endpoint ---
 
76
  prompt = request.prompt
77
  logger.info(f"Generating audio for prompt: '{prompt}'")
78
 
 
79
  temp_file_path = ""
80
  try:
 
81
  audio = pipe(
82
  prompt,
83
  num_inference_steps=200,
 
87
 
88
  sample_rate = 16000
89
 
 
90
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
91
  temp_file_path = temp_file.name
 
 
92
  audio_int16 = (audio * 32767).astype(np.int16)
 
 
93
  write_wav(temp_file_path, sample_rate, audio_int16)
94
  logger.info(f"Audio saved to temporary file: {temp_file_path}")
95
 
 
96
  return FileResponse(
97
  path=temp_file_path,
98
  media_type='audio/wav',
99
  filename=f"{prompt[:50].replace(' ', '_')}.wav",
100
+ background=os.remove(temp_file_path)
101
  )
102
 
103
  except Exception as e:
104
  logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
 
105
  if temp_file_path and os.path.exists(temp_file_path):
106
  os.remove(temp_file_path)
107
  raise HTTPException(status_code=500, detail=str(e))