Janiopi's picture
Update app.py
64139ab verified
# app.py - Versión corregida con padding
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from transformers import pipeline
import tempfile
import os
import uvicorn
import librosa
import soundfile as sf
import numpy as np
# Crear app FastAPI
app = FastAPI(
title="Musical Instrument Detection API",
description="API para detectar instrumentos musicales en audio",
version="1.0.0"
)
# Configurar CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Variable global para el modelo
classifier = None
@app.on_event("startup")
async def startup_event():
"""Cargar modelo al iniciar la aplicación"""
global classifier
try:
print("🔄 Cargando modelo...")
# Configurar pipeline con padding y truncación
classifier = pipeline(
"audio-classification",
model="Janiopi/detector_de_instrumentos_v1",
feature_extractor_kwargs={
"padding": True,
"truncation": True,
"max_length": 240000, # 15 segundos a 16kHz
"return_tensors": "pt"
}
)
print("✅ Modelo cargado exitosamente con configuración de padding")
except Exception as e:
print(f"❌ Error cargando modelo: {e}")
classifier = None
@app.get("/", response_class=HTMLResponse)
async def root():
"""Página principal con documentación"""
html_content = """
<!DOCTYPE html>
<html>
<head>
<title>Musical Instrument Detection API</title>
<style>
body { font-family: Arial, sans-serif; margin: 40px; }
.endpoint { background: #f0f0f0; padding: 10px; margin: 10px 0; border-radius: 5px; }
.method { color: white; padding: 2px 8px; border-radius: 3px; font-weight: bold; }
.get { background: #61affe; }
.post { background: #49cc90; }
</style>
</head>
<body>
<h1>🎵 Musical Instrument Detection API</h1>
<p>API para detectar instrumentos musicales (Guitarra, Piano, Batería)</p>
<h2>📡 Endpoints Disponibles:</h2>
<div class="endpoint">
<span class="method get">GET</span> <strong>/health</strong>
<p>Verificar estado del servicio y modelo</p>
</div>
<div class="endpoint">
<span class="method post">POST</span> <strong>/detect</strong>
<p>Detectar instrumentos en archivo de audio</p>
<p><strong>Content-Type:</strong> multipart/form-data</p>
<p><strong>Parámetro:</strong> audio (archivo)</p>
</div>
<div class="endpoint">
<span class="method get">GET</span> <strong>/docs</strong>
<p>Documentación interactiva de la API (Swagger)</p>
</div>
<h2>📱 Uso desde Android:</h2>
<pre style="background: #f8f8f8; padding: 15px; border-radius: 5px;">
POST https://janiopi-musical-detector-api.hf.space/detect
Content-Type: multipart/form-data
Body: audio file (campo "audio")
Respuesta:
{
"success": true,
"results": [
{
"label": "Sound_Guitar",
"score": 0.8547
}
],
"filename": "audio.wav"
}
</pre>
</body>
</html>
"""
return html_content
@app.get("/health")
async def health_check():
"""Verificar estado del servicio"""
return {
"status": "online",
"model_loaded": classifier is not None,
"message": "API funcionando correctamente",
"model_info": "Janiopi/detector_de_instrumentos_v1",
"supported_instruments": ["Guitar", "Piano", "Drum"],
"max_duration_seconds": 15,
"sample_rate": 16000
}
@app.post("/detect")
async def detect_instrument(audio: UploadFile = File(...)):
"""
Detectar instrumentos musicales en archivo de audio
"""
try:
if classifier is None:
raise HTTPException(
status_code=503,
detail="Modelo no disponible. Intenta más tarde."
)
print(f"📁 Procesando: {audio.filename} ({audio.content_type})")
# Leer contenido
content = await audio.read()
print(f"📏 Tamaño: {len(content)} bytes")
# Crear archivo temporal
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
temp_file.write(content)
temp_path = temp_file.name
try:
print("🎵 Cargando audio con librosa...")
# Cargar audio con librosa (maneja múltiples formatos)
audio_data, sample_rate = librosa.load(temp_path, sr=16000)
print(f"🔊 Audio cargado: {len(audio_data)} samples a {sample_rate}Hz")
print(f"⏱️ Duración: {len(audio_data)/sample_rate:.2f} segundos")
# Verificar duración mínima
if len(audio_data) < 1600: # Menos de 0.1 segundos
raise ValueError("Audio demasiado corto (mínimo 0.1 segundos)")
# Truncar a máximo 15 segundos
max_samples = 15 * 16000
if len(audio_data) > max_samples:
audio_data = audio_data[:max_samples]
print(f"🔄 Audio truncado a 15 segundos")
# Asegurar que el audio tenga el formato correcto
audio_data = np.array(audio_data, dtype=np.float32)
# Guardar como WAV temporal para el modelo
temp_wav_path = temp_path.replace('.wav', '_processed.wav')
sf.write(temp_wav_path, audio_data, sample_rate)
print(f"💾 Audio guardado como: {temp_wav_path}")
print("🤖 Ejecutando modelo...")
# Procesar con el modelo
results = classifier(temp_wav_path)
print(f"🎯 Resultados raw: {results}")
# Limpiar archivo WAV procesado
if os.path.exists(temp_wav_path):
os.unlink(temp_wav_path)
# Formatear resultados
formatted_results = []
for result in results:
formatted_results.append({
"label": result["label"],
"score": round(float(result["score"]), 4)
})
# Ordenar por score descendente
formatted_results.sort(key=lambda x: x["score"], reverse=True)
print(f"✅ Resultados formateados: {formatted_results}")
return {
"success": True,
"results": formatted_results,
"filename": audio.filename,
"audio_info": {
"samples": len(audio_data),
"sample_rate": sample_rate,
"duration_seconds": round(len(audio_data) / sample_rate, 2),
"processed_size_bytes": len(content)
}
}
finally:
# Limpiar archivo temporal original
if os.path.exists(temp_path):
os.unlink(temp_path)
except HTTPException:
raise
except Exception as e:
print(f"❌ Error inesperado: {e}")
import traceback
traceback.print_exc()
# Mensajes de error más específicos
error_msg = str(e)
if "Unable to create tensor" in error_msg:
detail = "Error de formato de audio. Intenta con un archivo WAV de mejor calidad."
elif "too short" in error_msg.lower():
detail = "Audio demasiado corto. Graba al menos 1 segundo."
elif "padding" in error_msg:
detail = "Error de procesamiento de audio. Intenta con un archivo diferente."
else:
detail = f"Error procesando audio: {error_msg}"
raise HTTPException(status_code=500, detail=detail)
@app.get("/test")
async def test_endpoint():
"""Endpoint de prueba para verificar conectividad"""
return {
"message": "API funcionando",
"timestamp": "2025-01-16",
"test": "ok"
}
# Ejecutar la aplicación
if __name__ == "__main__":
print("🚀 Iniciando Musical Instrument Detection API...")
uvicorn.run(
app,
host="0.0.0.0",
port=7860,
log_level="info"
)