# app.py - Versión corregida con padding from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from transformers import pipeline import tempfile import os import uvicorn import librosa import soundfile as sf import numpy as np # Crear app FastAPI app = FastAPI( title="Musical Instrument Detection API", description="API para detectar instrumentos musicales en audio", version="1.0.0" ) # Configurar CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Variable global para el modelo classifier = None @app.on_event("startup") async def startup_event(): """Cargar modelo al iniciar la aplicación""" global classifier try: print("🔄 Cargando modelo...") # Configurar pipeline con padding y truncación classifier = pipeline( "audio-classification", model="Janiopi/detector_de_instrumentos_v1", feature_extractor_kwargs={ "padding": True, "truncation": True, "max_length": 240000, # 15 segundos a 16kHz "return_tensors": "pt" } ) print("✅ Modelo cargado exitosamente con configuración de padding") except Exception as e: print(f"❌ Error cargando modelo: {e}") classifier = None @app.get("/", response_class=HTMLResponse) async def root(): """Página principal con documentación""" html_content = """
API para detectar instrumentos musicales (Guitarra, Piano, Batería)
Verificar estado del servicio y modelo
Detectar instrumentos en archivo de audio
Content-Type: multipart/form-data
Parámetro: audio (archivo)
Documentación interactiva de la API (Swagger)
POST https://janiopi-musical-detector-api.hf.space/detect
Content-Type: multipart/form-data
Body: audio file (campo "audio")
Respuesta:
{
"success": true,
"results": [
{
"label": "Sound_Guitar",
"score": 0.8547
}
],
"filename": "audio.wav"
}
"""
return html_content
@app.get("/health")
async def health_check():
"""Verificar estado del servicio"""
return {
"status": "online",
"model_loaded": classifier is not None,
"message": "API funcionando correctamente",
"model_info": "Janiopi/detector_de_instrumentos_v1",
"supported_instruments": ["Guitar", "Piano", "Drum"],
"max_duration_seconds": 15,
"sample_rate": 16000
}
@app.post("/detect")
async def detect_instrument(audio: UploadFile = File(...)):
"""
Detectar instrumentos musicales en archivo de audio
"""
try:
if classifier is None:
raise HTTPException(
status_code=503,
detail="Modelo no disponible. Intenta más tarde."
)
print(f"📁 Procesando: {audio.filename} ({audio.content_type})")
# Leer contenido
content = await audio.read()
print(f"📏 Tamaño: {len(content)} bytes")
# Crear archivo temporal
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
temp_file.write(content)
temp_path = temp_file.name
try:
print("🎵 Cargando audio con librosa...")
# Cargar audio con librosa (maneja múltiples formatos)
audio_data, sample_rate = librosa.load(temp_path, sr=16000)
print(f"🔊 Audio cargado: {len(audio_data)} samples a {sample_rate}Hz")
print(f"⏱️ Duración: {len(audio_data)/sample_rate:.2f} segundos")
# Verificar duración mínima
if len(audio_data) < 1600: # Menos de 0.1 segundos
raise ValueError("Audio demasiado corto (mínimo 0.1 segundos)")
# Truncar a máximo 15 segundos
max_samples = 15 * 16000
if len(audio_data) > max_samples:
audio_data = audio_data[:max_samples]
print(f"🔄 Audio truncado a 15 segundos")
# Asegurar que el audio tenga el formato correcto
audio_data = np.array(audio_data, dtype=np.float32)
# Guardar como WAV temporal para el modelo
temp_wav_path = temp_path.replace('.wav', '_processed.wav')
sf.write(temp_wav_path, audio_data, sample_rate)
print(f"💾 Audio guardado como: {temp_wav_path}")
print("🤖 Ejecutando modelo...")
# Procesar con el modelo
results = classifier(temp_wav_path)
print(f"🎯 Resultados raw: {results}")
# Limpiar archivo WAV procesado
if os.path.exists(temp_wav_path):
os.unlink(temp_wav_path)
# Formatear resultados
formatted_results = []
for result in results:
formatted_results.append({
"label": result["label"],
"score": round(float(result["score"]), 4)
})
# Ordenar por score descendente
formatted_results.sort(key=lambda x: x["score"], reverse=True)
print(f"✅ Resultados formateados: {formatted_results}")
return {
"success": True,
"results": formatted_results,
"filename": audio.filename,
"audio_info": {
"samples": len(audio_data),
"sample_rate": sample_rate,
"duration_seconds": round(len(audio_data) / sample_rate, 2),
"processed_size_bytes": len(content)
}
}
finally:
# Limpiar archivo temporal original
if os.path.exists(temp_path):
os.unlink(temp_path)
except HTTPException:
raise
except Exception as e:
print(f"❌ Error inesperado: {e}")
import traceback
traceback.print_exc()
# Mensajes de error más específicos
error_msg = str(e)
if "Unable to create tensor" in error_msg:
detail = "Error de formato de audio. Intenta con un archivo WAV de mejor calidad."
elif "too short" in error_msg.lower():
detail = "Audio demasiado corto. Graba al menos 1 segundo."
elif "padding" in error_msg:
detail = "Error de procesamiento de audio. Intenta con un archivo diferente."
else:
detail = f"Error procesando audio: {error_msg}"
raise HTTPException(status_code=500, detail=detail)
@app.get("/test")
async def test_endpoint():
"""Endpoint de prueba para verificar conectividad"""
return {
"message": "API funcionando",
"timestamp": "2025-01-16",
"test": "ok"
}
# Ejecutar la aplicación
if __name__ == "__main__":
print("🚀 Iniciando Musical Instrument Detection API...")
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
log_level="info"
)