# BettaVox / app.py
# Author: Manubett1234
# Last commit: b690425 "Update app.py" (verified)
import os
import pickle
import traceback
import uuid

import joblib
import numpy as np
from flask_cors import CORS
from flask import Flask, request, render_template, jsonify
from werkzeug.utils import secure_filename

from extract import extract_features # Import feature extractor
# Create the Flask application; CORS is enabled for every route and origin.
app = Flask(__name__)
CORS(app)

# Uploads are staged in a temporary directory (writable on most hosts), and
# only audio files with one of these extensions are accepted.
UPLOAD_FOLDER = "/tmp/uploads"
ALLOWED_EXTENSIONS = {"wav", "mp3", "ogg", "m4a"}

# Create the staging directory up front (no-op if it already exists) and
# expose its path through the app configuration.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)
# Load the trained classifier, its feature scaler and the training feature list.
# Artifacts are resolved relative to this file rather than the process's
# working directory. This matches how Flask locates the templates folder, so the
# app starts correctly no matter where it is launched from. Set the MODEL_DIR
# environment variable to load the artifacts from a different directory.
MODEL_DIR = os.environ.get(
    "MODEL_DIR",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "models"),
)
model_path = os.path.join(MODEL_DIR, "gender_model_svm.pkl")
scaler_path = os.path.join(MODEL_DIR, "scaler_gender_model_svm.pkl")
feature_list_path = os.path.join(MODEL_DIR, "feature_list.pkl")

# NOTE: joblib.load and pickle.load can execute arbitrary code embedded in the
# file, so these artifacts must come from a trusted source (presumably they
# ship alongside this app -- never load user-supplied files here).
model = joblib.load(model_path)
scaler = joblib.load(scaler_path)
with open(feature_list_path, "rb") as f:
    feature_list = pickle.load(f)

# Number of features the model was trained on; extracted feature vectors of
# any other length are rejected before scaling.
EXPECTED_FEATURE_COUNT = len(feature_list)
print(f"βœ… Model, Scaler, and Feature List Loaded Successfully! (Expecting {EXPECTED_FEATURE_COUNT} features)")
# Function to check valid file extensions
def allowed_file(filename):
    """Report whether *filename* carries an accepted audio extension.

    The extension is the text after the last ``.`` and is compared
    case-insensitively against ``ALLOWED_EXTENSIONS``. Names containing no
    dot at all are rejected.
    """
    if "." not in filename:
        return False
    _stem, _dot, extension = filename.rpartition(".")
    return extension.lower() in ALLOWED_EXTENSIONS
# Route to render the HTML interface
@app.route("/")
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
# Route to handle file upload and prediction
@app.route("/predict", methods=["POST"])
def predict():
    """Predict the speaker's gender from an uploaded audio file.

    The request must carry the audio as a multipart upload under the
    ``"audio"`` field, and its filename must have an extension accepted by
    ``allowed_file``. Processing steps:

    1. Stage the upload in ``app.config["UPLOAD_FOLDER"]`` under a unique name.
    2. Convert it to a feature vector with ``extract_features``.
    3. Scale the features with ``scaler`` and classify them with ``model``.

    The staged file is deleted before returning, whatever the outcome.

    Returns:
        On success, JSON ``{"gender", "confidence", "age_group"}``:
        ``gender`` is ``"Female"`` when the model predicts label 1 and
        ``"Male"`` otherwise; ``confidence`` is the highest class probability
        from ``model.predict_proba``; ``age_group`` is a fixed placeholder.
        On failure, JSON ``{"error": ...}`` with status 400 for a missing or
        unsupported upload, or 500 for processing errors.
    """
    # Bound before the ``try`` so the cleanup in ``finally`` is safe on the
    # early-return paths that never write a file to disk.
    filepath = None
    try:
        if "audio" not in request.files:
            print("❌ No file uploaded")
            return jsonify({"error": "No file uploaded"}), 400

        file = request.files["audio"]
        print(f"πŸ“₯ Received file: {file.filename}, Type: {file.content_type}")  # Debugging line

        if not file.filename:
            return jsonify({"error": "No selected file"}), 400

        if not allowed_file(file.filename):
            allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
            return jsonify({"error": f"Unsupported file type. Allowed extensions: {allowed}"}), 400

        filename = secure_filename(file.filename)
        # allowed_file() has guaranteed a dot-separated, whitelisted extension.
        extension = file.filename.rsplit(".", 1)[1].lower()
        # Save under a unique name. Otherwise concurrent uploads sharing a
        # filename would overwrite (and then delete) each other's audio. Also,
        # secure_filename() can reduce a non-ASCII name to a bare "wav" with no
        # extension.
        filepath = os.path.join(
            app.config["UPLOAD_FOLDER"], f"{uuid.uuid4().hex}.{extension}"
        )
        file.save(filepath)
        print(f"🟒 Processing file: {filename}")

        features = extract_features(filepath)
        if features is None:
            print("❌ Feature extraction failed")
            return jsonify({"error": "Feature extraction failed"}), 500
        print(f"🟒 Extracted {len(features)} features")

        # The scaler and model expect exactly the feature count used in training.
        if len(features) != EXPECTED_FEATURE_COUNT:
            print(f"❌ Feature count mismatch! Extracted: {len(features)}, Expected: {EXPECTED_FEATURE_COUNT}")
            return jsonify({"error": f"Feature count mismatch. Expected {EXPECTED_FEATURE_COUNT}, got {len(features)}"}), 500

        # The scaler/model (presumably scikit-learn) take a batch of samples,
        # so the single feature vector is wrapped in a list.
        features_scaled = scaler.transform([features])
        print("🟒 Features scaled successfully")

        prediction = model.predict(features_scaled)[0]
        # NOTE(review): an sklearn SVC only exposes predict_proba when trained
        # with probability=True -- confirm for gender_model_svm.pkl.
        confidence = model.predict_proba(features_scaled)[0]
        print("🟒 Prediction completed")

        result = {
            "gender": "Female" if prediction == 1 else "Male",
            "confidence": float(max(confidence)),
            "age_group": "Unknown",  # Temporary placeholder
        }
        print(f"βœ… Prediction Result: {result}")
        return jsonify(result)
    except Exception as e:
        print(f"❌ ERROR: {e}")
        traceback.print_exc()  # Print the full error traceback to the server log
        return jsonify({"error": str(e)}), 500
    finally:
        # Remove the staged upload (if one was written) to free storage.
        if filepath is not None and os.path.exists(filepath):
            os.remove(filepath)
            print(f"πŸ—‘οΈ Deleted temp file: {filepath}")
# Run Flask app
if __name__ == "__main__":
    # Werkzeug's interactive debugger lets anyone who can reach the server run
    # arbitrary Python. It must not be enabled by default while binding to
    # 0.0.0.0, so it is opt-in for local development via FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    # 7860 remains the default port; PORT allows overriding it per deployment.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=debug_enabled, use_reloader=False)