File size: 4,259 Bytes
b7becdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
import tempfile
import os
import base64
from typing import Dict, Any
import shutil
import json
import numpy as np

from app import get_all_model_predictions, create_confidence_chart, create_voting_chart

def convert_numpy_types(obj):
    """Recursively convert NumPy values into native Python types.

    Used to normalise model outputs before they are placed in a JSON
    response, since NumPy scalars and arrays are not plain Python values.

    Args:
        obj: Any value. Handled cases:
            - NumPy bool/integer/floating scalars -> ``bool``/``int``/``float``
            - ``np.ndarray`` -> (nested) list via ``ndarray.tolist()``
            - ``dict`` -> new dict with each value converted (keys unchanged)
            - ``list``/``tuple`` -> new ``list`` with each item converted

    Returns:
        The converted value; anything not listed above is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # without it NumPy booleans would be returned unconverted.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(item) for item in obj]
    return obj

app = FastAPI(title="Engine Sound Classifier")

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): Starlette's CORS documentation discourages combining a
# wildcard origin list with allow_credentials=True; confirm credentials
# (cookies/auth headers) are actually needed by the front end.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Expose ./static under /static. The directory is created first because
# StaticFiles (check_dir defaults to True) verifies it exists on construction.
_STATIC_DIR = "static"
os.makedirs(_STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the front-end page.

    Reads ``index.html`` (resolved relative to the current working
    directory) on every request and returns its text as an HTML response.
    """
    # Decode explicitly as UTF-8: without an encoding argument, open() uses
    # the platform's locale encoding, which can mis-decode or reject
    # non-ASCII markup. Assumes index.html is saved as UTF-8.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()

def _encode_file_base64(path):
    """Return the base64 (UTF-8 text) encoding of the file at *path*.

    Returns None when *path* is falsy or the file does not exist, so a chart
    that could not be produced simply appears as ``None`` in the response.
    """
    if not path or not os.path.exists(path):
        return None
    with open(path, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode('utf-8')


@app.post("/api/predict")
async def predict_audio(file: UploadFile = File(...)):
    """Classify an uploaded audio clip with every available model.

    The upload is written to a temporary ``.wav`` file and passed to
    ``get_all_model_predictions``. The chart files produced by
    ``create_confidence_chart`` and ``create_voting_chart`` are embedded in
    the response as base64 strings. All temporary files (the upload and both
    charts) are removed afterwards, including when prediction or chart
    generation raises.

    Returns:
        dict with keys ``predictions``, ``best_model``, ``best_prediction``,
        ``confidence``, ``confidence_chart`` and ``voting_chart``. The chart
        values are base64 text, or None when the chart file does not exist.
    """
    # Bind every path before the try block. The finally clause can then
    # always reference them, and an exception from prediction/charting is
    # not masked by an UnboundLocalError during cleanup.
    tmp_path = None
    confidence_chart_path = None
    voting_chart_path = None
    try:
        # Record the temp path before copying so a partially written file is
        # still cleaned up if the copy fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        results, best_model, best_prediction, highest_confidence = get_all_model_predictions(tmp_path)

        # Normalise NumPy scalars/arrays in everything that ends up in the
        # JSON response (e.g. a NumPy integer label or confidence).
        results = convert_numpy_types(results)
        best_prediction = convert_numpy_types(best_prediction)
        highest_confidence = convert_numpy_types(highest_confidence)

        confidence_chart_path = create_confidence_chart(results, best_model)
        voting_chart_path = create_voting_chart(results)

        return {
            "predictions": results,
            "best_model": best_model,
            "best_prediction": best_prediction,
            "confidence": highest_confidence,
            "confidence_chart": _encode_file_base64(confidence_chart_path),
            "voting_chart": _encode_file_base64(voting_chart_path)
        }

    finally:
        # Best-effort cleanup of every temporary file that was created.
        for path in (tmp_path, confidence_chart_path, voting_chart_path):
            if not path:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                # Never written or already removed; nothing to clean up.
                pass

@app.get("/api/models")
async def get_models():
    """List the display names of the known models present on disk.

    Looks in the ``models`` directory (relative to the current working
    directory) for the ``*_model.joblib`` files named in ``model_names``.
    Files not in that mapping are ignored. Names are returned in the
    mapping's fixed order, so the response does not depend on the
    filesystem's directory-listing order.

    Returns:
        ``{"models": [display_name, ...]}``. The list is empty when the
        ``models`` directory is missing or is not a directory.
    """
    models_dir = 'models'
    model_names = {
        'lr_model.joblib': 'Logistic Regression',
        'nn_model.joblib': 'Neural Network',
        'rf_model.joblib': 'Random Forest',
        'svm_model.joblib': 'Support Vector Machine',
        'xgb_model.joblib': 'XGBoost'
    }

    # isdir (not exists) so a stray *file* named "models" does not make
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(models_dir):
        return {"models": []}

    present = set(os.listdir(models_dir))
    available_models = [
        label for filename, label in model_names.items() if filename in present
    ]
    return {"models": available_models}

if __name__ == "__main__":
    # Run a uvicorn server on all interfaces, port 7860, when this module
    # is executed directly as a script.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")