# sample-noise / api.py
# Kumar Shubham
# Initial push
# b7becdf
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
import tempfile
import os
import base64
from typing import Dict, Any
import shutil
import json
import numpy as np
from app import get_all_model_predictions, create_confidence_chart, create_voting_chart
def convert_numpy_types(obj):
    """Recursively convert NumPy values into native Python types.

    The result can be handed to a JSON encoder:

    * NumPy booleans, integers and floats become ``bool``/``int``/``float``.
    * ``ndarray`` values become (nested) lists.
    * Dicts are rebuilt with converted values; keys that are NumPy scalars are
      converted as well.
    * Lists and tuples become lists of converted items.

    Anything else is returned unchanged.

    Args:
        obj: Arbitrary value, possibly containing NumPy types.

    Returns:
        The converted value.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        # Only NumPy-scalar keys are converted: json rejects e.g. np.int64 keys.
        # Other keys (such as tuples) are left alone so they stay hashable.
        return {
            (convert_numpy_types(key) if isinstance(key, np.generic) else key):
                convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        # Tuples deliberately become lists, matching how JSON represents them.
        return [convert_numpy_types(item) for item in obj]
    return obj
app = FastAPI(title="Engine Sound Classifier")

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_credentials=True together with a wildcard origin makes
# Starlette echo the caller's Origin back on cookie-bearing requests. No route
# here reads cookies, so allow_credentials=False is probably sufficient.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Serve the ./static directory (relative to the working directory) at /static.
# The directory is created up front because StaticFiles checks that it
# exists when it is constructed.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the contents of ``index.html`` as the landing page.

    The path is relative to the current working directory, so the server must
    be started from the directory that contains ``index.html``.

    Returns:
        The HTML text of ``index.html``.
    """
    # Decode explicitly as UTF-8 (index.html is assumed to be UTF-8 encoded).
    # The platform default encoding, e.g. cp1252 on Windows, would mis-decode
    # or reject non-ASCII characters.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
def _read_file_as_base64(path):
    """Return the file at ``path`` base64-encoded as text, or ``None``.

    ``None`` is returned when ``path`` is falsy or the file does not exist, so
    a chart that was not produced is simply omitted from the response.
    """
    if not path or not os.path.exists(path):
        return None
    with open(path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')


def _remove_if_exists(path):
    """Delete ``path`` when it is set and the file is still present."""
    if path and os.path.exists(path):
        os.unlink(path)


@app.post("/api/predict")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio file with every model and return the results.

    The upload is spooled to a temporary file, which is then passed by path to
    ``get_all_model_predictions``. Two charts are rendered from the results.
    The temporary audio file and both chart files are deleted before the
    handler returns, whether or not processing succeeded.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        A dict with keys ``predictions``, ``best_model``, ``best_prediction``,
        ``confidence``, ``confidence_chart`` and ``voting_chart``. The two
        chart values are the base64-encoded chart file contents, or ``None``
        when a chart file was not produced.
    """
    # Bind every path the cleanup touches *before* anything can fail. This
    # prevents the finally clause from raising UnboundLocalError and masking
    # the real exception (e.g. a model rejecting a corrupt upload).
    tmp_path = None
    confidence_chart_path = None
    voting_chart_path = None
    try:
        # delete=False keeps the file after the handle closes so it can be
        # reopened by path; it is removed in the finally clause. tmp_path is
        # recorded before copying so a failed copy does not leak the file.
        # NOTE(review): the suffix is always '.wav', whatever format was uploaded.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        # NOTE(review): these calls are synchronous inside an async handler and
        # block the event loop while they run; consider run_in_threadpool.
        results, best_model, best_prediction, highest_confidence = get_all_model_predictions(tmp_path)

        # Normalise NumPy scalars/arrays so the response is JSON-serialisable.
        results = convert_numpy_types(results)
        best_prediction = convert_numpy_types(best_prediction)
        highest_confidence = convert_numpy_types(highest_confidence)

        confidence_chart_path = create_confidence_chart(results, best_model)
        voting_chart_path = create_voting_chart(results)

        return {
            "predictions": results,
            "best_model": best_model,
            "best_prediction": best_prediction,
            "confidence": highest_confidence,
            "confidence_chart": _read_file_as_base64(confidence_chart_path),
            "voting_chart": _read_file_as_base64(voting_chart_path),
        }
    finally:
        for path in (tmp_path, confidence_chart_path, voting_chart_path):
            _remove_if_exists(path)
# Model artifact filenames (inside the models directory) -> display names.
# The insertion order here is the order in which /api/models reports models.
_MODEL_DISPLAY_NAMES = {
    'lr_model.joblib': 'Logistic Regression',
    'nn_model.joblib': 'Neural Network',
    'rf_model.joblib': 'Random Forest',
    'svm_model.joblib': 'Support Vector Machine',
    'xgb_model.joblib': 'XGBoost'
}


@app.get("/api/models")
async def get_models():
    """Return the display names of the known models present on disk.

    Only entries in the ``models`` directory (relative to the working
    directory) whose names appear in ``_MODEL_DISPLAY_NAMES`` are reported;
    other files are ignored. Results follow the mapping's order rather than
    ``os.listdir`` order, so the response is stable across calls and platforms.

    Returns:
        ``{"models": [...]}``. The list is empty when ``models`` is not a
        directory.
    """
    models_dir = 'models'
    # isdir rather than exists: a stray *file* named "models" would make
    # os.listdir raise.
    if not os.path.isdir(models_dir):
        return {"models": []}
    present = set(os.listdir(models_dir))
    available_models = [
        display_name
        for filename, display_name in _MODEL_DISPLAY_NAMES.items()
        if filename in present
    ]
    return {"models": available_models}
if __name__ == "__main__":
    # uvicorn is imported here because it is only needed when this module is
    # executed directly, not when an external server imports `app`.
    import uvicorn

    # Listen on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")