alonb19 commited on
Commit
6c6da3d
verified
1 Parent(s): ec1cc7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -24
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from transformers import pipeline, AutoProcessor, AutoModelForAudioClassification
4
  import tempfile
5
  import os
6
  import uvicorn
@@ -11,9 +11,11 @@ from datetime import datetime
11
  import torch
12
  from contextlib import asynccontextmanager
13
 
14
- # Configurar cache de Hugging Face
15
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
16
  os.environ['HF_HOME'] = '/tmp/huggingface'
 
 
17
 
18
  # Configurar logging
19
  logging.basicConfig(
@@ -34,9 +36,10 @@ async def load_model():
34
  # Crear directorios de cache
35
  os.makedirs('/tmp/transformers_cache', exist_ok=True)
36
  os.makedirs('/tmp/huggingface', exist_ok=True)
 
37
 
38
- # Usar modelo m谩s simple y confiable
39
- model_name = "superb/hubert-base-superb-ic"
40
 
41
  logger.info(f"Cargando modelo: {model_name}")
42
 
@@ -51,20 +54,7 @@ async def load_model():
51
 
52
  except Exception as e:
53
  logger.error(f"Error cargando modelo: {e}")
54
-
55
- # Fallback a modelo b谩sico
56
- try:
57
- logger.info("Intentando modelo alternativo...")
58
- classifier = pipeline(
59
- "audio-classification",
60
- model="facebook/wav2vec2-base",
61
- device=-1,
62
- return_all_scores=True
63
- )
64
- logger.info("Modelo alternativo cargado")
65
- except Exception as e2:
66
- logger.error(f"Error con modelo alternativo: {e2}")
67
- classifier = None
68
 
69
  async def cleanup_model():
70
  """Limpiar recursos"""
@@ -121,6 +111,24 @@ async def health_check():
121
  "cache_dir": "/tmp/transformers_cache"
122
  }
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  @app.post("/detect")
125
  async def detect_instrument(audio: UploadFile = File(...)):
126
  """Detectar instrumentos musicales en archivo de audio"""
@@ -161,12 +169,8 @@ async def detect_instrument(audio: UploadFile = File(...)):
161
  try:
162
  logger.info("Cargando audio...")
163
 
164
- # Cargar audio
165
- audio_data, sample_rate = librosa.load(
166
- temp_path,
167
- sr=16000,
168
- mono=True
169
- )
170
 
171
  duration = len(audio_data) / sample_rate
172
  logger.info(f"Audio: {duration:.2f}s, {sample_rate}Hz")
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import pipeline
4
  import tempfile
5
  import os
6
  import uvicorn
 
11
  import torch
12
  from contextlib import asynccontextmanager
13
 
14
+ # Configurar cache y deshabilitar numba cache
15
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
16
  os.environ['HF_HOME'] = '/tmp/huggingface'
17
+ os.environ['NUMBA_CACHE_DIR'] = '/tmp/numba_cache'
18
+ os.environ['NUMBA_DISABLE_JIT'] = '1' # Deshabilitar JIT de numba
19
 
20
  # Configurar logging
21
  logging.basicConfig(
 
36
  # Crear directorios de cache
37
  os.makedirs('/tmp/transformers_cache', exist_ok=True)
38
  os.makedirs('/tmp/huggingface', exist_ok=True)
39
+ os.makedirs('/tmp/numba_cache', exist_ok=True)
40
 
41
+ # Usar modelo m谩s simple para audio
42
+ model_name = "facebook/wav2vec2-base-960h"
43
 
44
  logger.info(f"Cargando modelo: {model_name}")
45
 
 
54
 
55
  except Exception as e:
56
  logger.error(f"Error cargando modelo: {e}")
57
+ classifier = None
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  async def cleanup_model():
60
  """Limpiar recursos"""
 
111
  "cache_dir": "/tmp/transformers_cache"
112
  }
113
 
114
+ def load_audio_simple(file_path):
115
+ """Cargar audio sin usar funciones complejas de librosa"""
116
+ try:
117
+ # Cargar audio de forma simple
118
+ y, sr = librosa.load(file_path, sr=16000, mono=True)
119
+ return y, sr
120
+ except Exception as e:
121
+ logger.error(f"Error con librosa.load: {e}")
122
+ # Fallback usando soundfile
123
+ import soundfile as sf
124
+ y, sr = sf.read(file_path)
125
+ if sr != 16000:
126
+ # Resample simple
127
+ from scipy import signal
128
+ y = signal.resample(y, int(len(y) * 16000 / sr))
129
+ sr = 16000
130
+ return y, sr
131
+
132
  @app.post("/detect")
133
  async def detect_instrument(audio: UploadFile = File(...)):
134
  """Detectar instrumentos musicales en archivo de audio"""
 
169
  try:
170
  logger.info("Cargando audio...")
171
 
172
+ # Cargar audio con funci贸n simple
173
+ audio_data, sample_rate = load_audio_simple(temp_path)
 
 
 
 
174
 
175
  duration = len(audio_data) / sample_rate
176
  logger.info(f"Audio: {duration:.2f}s, {sample_rate}Hz")