Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import tempfile | |
| import os | |
| import base64 | |
| from typing import Dict, Any | |
| import shutil | |
| import json | |
| import numpy as np | |
| from app import get_all_model_predictions, create_confidence_chart, create_voting_chart | |
def convert_numpy_types(obj):
    """Recursively convert numpy values inside *obj* to native Python types.

    Used to make model output JSON-serializable before it is returned.

    Conversions:
        * numpy bool scalar    -> ``bool``
        * numpy integer scalar -> ``int``
        * numpy floating scalar -> ``float``
        * ``np.ndarray``       -> nested list via ``tolist()``
        * ``dict``             -> new dict, values converted (keys unchanged)
        * ``list`` / ``tuple`` -> new ``list``, items converted (tuples become lists)

    Any other object is returned unchanged.
    """
    # np.bool_ is neither an np.integer nor an np.floating subclass, so it
    # needs its own branch; without it numpy booleans leaked through unconverted.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(item) for item in obj]
    return obj
app = FastAPI(title="Engine Sound Classifier")

# CORS: allow browser calls from any origin, with any method and header.
# allow_credentials is False because FastAPI documents that allow_origins,
# allow_methods and allow_headers must not be ["*"] when credentials are
# allowed. Nothing in this file reads cookies or auth headers, so enabling
# credentials gained nothing and only widened exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve ./static under /static. The directory is created first so that the
# mount has something to point at on a fresh checkout.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
async def root():
    """Return the text of ``index.html`` from the current working directory.

    Returns:
        The file contents as a ``str``.

    Raises:
        FileNotFoundError: if ``index.html`` is not present.
    """
    # NOTE(review): no route decorator is visible for this handler. Given the
    # otherwise-unused HTMLResponse import, it was presumably registered as
    # @app.get("/", response_class=HTMLResponse). Confirm this, because an
    # undecorated handler is never served.
    # Decode explicitly as UTF-8. The platform default encoding (e.g. cp1252
    # on Windows) can fail on non-ASCII characters in the page.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
def _encode_file_base64(path):
    """Return the base64 (UTF-8 text) encoding of the file at *path*, or None.

    *path* may be falsy (for example None when a chart helper returned
    nothing). None is also returned when the path does not exist.
    """
    if not path or not os.path.exists(path):
        return None
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode('utf-8')


async def predict_audio(file: UploadFile = File(...)):
    """Run all models on an uploaded audio file and return their predictions.

    The upload is copied to a temporary file with a ``.wav`` suffix and passed
    to ``get_all_model_predictions``. The two chart helpers are then called,
    and any chart files they produce are embedded as base64 strings. The temp
    upload and the chart files are deleted before returning, even on error.

    Returns:
        dict with keys ``"predictions"``, ``"best_model"``,
        ``"best_prediction"``, ``"confidence"``, ``"confidence_chart"`` and
        ``"voting_chart"``. A chart value is None when no chart file was
        produced.
    """
    # NOTE(review): no route decorator is visible for this handler. Confirm
    # that it is registered on `app`, otherwise it is unreachable.

    # Bind every path before the try block so that `finally` can always
    # reference them. Previously, an exception from the prediction or chart
    # helpers surfaced as an UnboundLocalError that hid the real error.
    tmp_path = None
    confidence_chart_path = None
    voting_chart_path = None
    try:
        # Create the temp file inside the try block so it is removed even if
        # copying the upload fails (delete=False means nobody else removes it).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        results, best_model, best_prediction, highest_confidence = get_all_model_predictions(tmp_path)

        # numpy scalars and arrays are not JSON-serializable, so normalize
        # every value that ends up in the response, not just `results`.
        results = convert_numpy_types(results)
        best_prediction = convert_numpy_types(best_prediction)
        highest_confidence = convert_numpy_types(highest_confidence)

        confidence_chart_path = create_confidence_chart(results, best_model)
        voting_chart_path = create_voting_chart(results)

        return {
            "predictions": results,
            "best_model": convert_numpy_types(best_model),
            "best_prediction": best_prediction,
            "confidence": highest_confidence,
            "confidence_chart": _encode_file_base64(confidence_chart_path),
            "voting_chart": _encode_file_base64(voting_chart_path),
        }
    finally:
        # Remove the uploaded temp file and any chart files that were created.
        for path in (tmp_path, confidence_chart_path, voting_chart_path):
            if path and os.path.exists(path):
                os.unlink(path)
async def get_models():
    """Return the display names of the known model files present in ``models/``.

    Only files listed in the table below are reported. Other files in the
    directory are ignored.

    Returns:
        ``{"models": [...]}``. Names appear in the table's fixed order. The
        list is empty when ``models/`` is missing or is not a directory.
    """
    # NOTE(review): no route decorator is visible for this handler. Confirm
    # that it is registered on `app`, otherwise it is unreachable.
    models_dir = 'models'
    model_names = {
        'lr_model.joblib': 'Logistic Regression',
        'nn_model.joblib': 'Neural Network',
        'rf_model.joblib': 'Random Forest',
        'svm_model.joblib': 'Support Vector Machine',
        'xgb_model.joblib': 'XGBoost'
    }
    if not os.path.isdir(models_dir):
        return {"models": []}
    # Walk the known-model table instead of os.listdir(), which returns
    # entries in arbitrary order, so the response order is deterministic.
    available_models = [
        display_name
        for filename, display_name in model_names.items()
        if os.path.isfile(os.path.join(models_dir, filename))
    ]
    return {"models": available_models}
if __name__ == "__main__":
    # uvicorn is imported only here because it is needed only when this
    # module is executed directly as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)