hswift commited on
Commit
4c26d1d
·
verified ·
1 Parent(s): 57f2351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -31,16 +31,26 @@ class AudioRequest(BaseModel):
31
  prompt: str
32
 
33
  # --- Model Loading ---
34
- # This section runs once when the application starts.
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
37
 
38
  logger.info(f"Using device: {device} with dtype: {torch_dtype}")
39
 
 
 
 
 
 
 
 
40
  try:
41
  # Use the stable, recommended model
42
  repo_id = "cvssp/audioldm-s-full-v2"
43
- pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype)
 
 
 
 
44
  pipe = pipe.to(device)
45
  logger.info(f"Successfully loaded model: {repo_id}")
46
  except Exception as e:
@@ -60,7 +70,6 @@ async def generate_audio_endpoint(request: AudioRequest):
60
  temp_file_path = ""
61
  try:
62
  # Generate the audio waveform
63
- # These are good parameters for this model
64
  audio = pipe(
65
  prompt,
66
  num_inference_steps=200,
@@ -68,9 +77,6 @@ async def generate_audio_endpoint(request: AudioRequest):
68
  guidance_scale=7.0
69
  ).audios[0]
70
 
71
- # --- THIS IS THE FIX ---
72
- # The cvssp/audioldm-s-full-v2 model has a fixed sample rate of 16000 Hz.
73
- # We set it directly here instead of trying to read it from a config that no longer exists.
74
  sample_rate = 16000
75
 
76
  # Create a temporary file to save the audio
@@ -85,7 +91,6 @@ async def generate_audio_endpoint(request: AudioRequest):
85
  logger.info(f"Audio saved to temporary file: {temp_file_path}")
86
 
87
  # Return the audio file as a response.
88
- # The file will be deleted after being sent.
89
  return FileResponse(
90
  path=temp_file_path,
91
  media_type='audio/wav',
 
31
  prompt: str
32
 
33
  # --- Model Loading ---
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
+ torch_dtype = torch.float16 if "cuda" in device else torch.float32
36
 
37
  logger.info(f"Using device: {device} with dtype: {torch_dtype}")
38
 
39
+ # --- FIX FOR PERMISSION ERROR ---
40
+ # The environment we're running in doesn't allow writing to the default '/.cache' directory.
41
+ # We explicitly define a writable directory within '/tmp' for the model cache.
42
+ CACHE_DIR = "/tmp/huggingface_cache"
43
+ os.makedirs(CACHE_DIR, exist_ok=True)
44
+ logger.info(f"Using model cache directory: {CACHE_DIR}")
45
+
46
  try:
47
  # Use the stable, recommended model
48
  repo_id = "cvssp/audioldm-s-full-v2"
49
+ pipe = AudioLDMPipeline.from_pretrained(
50
+ repo_id,
51
+ torch_dtype=torch_dtype,
52
+ cache_dir=CACHE_DIR # Pass the writable cache directory to the loader
53
+ )
54
  pipe = pipe.to(device)
55
  logger.info(f"Successfully loaded model: {repo_id}")
56
  except Exception as e:
 
70
  temp_file_path = ""
71
  try:
72
  # Generate the audio waveform
 
73
  audio = pipe(
74
  prompt,
75
  num_inference_steps=200,
 
77
  guidance_scale=7.0
78
  ).audios[0]
79
 
 
 
 
80
  sample_rate = 16000
81
 
82
  # Create a temporary file to save the audio
 
91
  logger.info(f"Audio saved to temporary file: {temp_file_path}")
92
 
93
  # Return the audio file as a response.
 
94
  return FileResponse(
95
  path=temp_file_path,
96
  media_type='audio/wav',