File size: 4,399 Bytes
dbe9059 76fbb44 dbe9059 76fbb44 dbe9059 76fbb44 dbe9059 971a67e 78316ec 971a67e d973dd8 78316ec ad8da74 78316ec ad8da74 78316ec ad8da74 d973dd8 78316ec d973dd8 78316ec d973dd8 78316ec d973dd8 78316ec d973dd8 971a67e d973dd8 971a67e dbe9059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from utils.validation import validate_and_prepare_input, InputValidationError
# Create the Flask application that exposes the engine-maintenance prediction routes.
# The module-level name (including its "mainteanance" spelling) is what the route
# decorators in this file reference, so it must not be renamed in isolation.
pred_mainteanance_api = Flask(import_name="Engine Maintenance Predictor")
# Load the trained engine-failure prediction model once, at import time, so every
# request reuses the same estimator. The relative path is resolved against the current
# working directory. joblib.load unpickles the file, so only point it at a trusted,
# locally produced artifact.
model = joblib.load(filename="best_eng_fail_pred_model.joblib")
# Landing route: answers GET / with a plain-text welcome message.
@pred_mainteanance_api.route('/', methods=['GET'])
def home():
    """Return the static welcome text shown at the service root."""
    greeting = "Welcome to the Engine Maintenance Prediction!"
    return greeting
# Define an endpoint that predicts whether a single engine needs maintenance
@pred_mainteanance_api.post('/v1/EngPredMaintenance')
def predict_need_maintenance():
    """Predict the maintenance need of one engine from a JSON object of sensor readings.

    The request body must be a single JSON object mapping feature names to values.
    It is turned into a one-row DataFrame, checked by ``validate_and_prepare_input``
    and passed to the loaded model.

    Returns:
        200 with ``{"status": "success", "prediction": <int>}``.
        400 with ``error_type`` ``"input_error"`` when the body is missing, is not
        valid JSON, or is not a JSON object.
        400 with ``error_type`` ``"validation_error"`` when
        ``validate_and_prepare_input`` raises ``InputValidationError``; the exception
        text is returned as ``message``.
        500 with ``error_type`` ``"internal_error"`` for any other failure; details
        are logged server-side and not exposed to the client.
    """
    # silent=True makes get_json return None instead of raising Flask's default
    # (HTML) 400/415 error for a malformed body or a non-JSON content type, so the
    # client always receives this API's JSON error format.
    input_json = request.get_json(silent=True)
    if not isinstance(input_json, dict):
        return jsonify({
            "status": "error",
            "error_type": "input_error",
            "message": "Request body must be a JSON object",
        }), 400
    try:
        input_df = pd.DataFrame([input_json])
        # Mirror the batch endpoint: integer-valued readings are cast to float64
        # before validation, so the same payload validates identically on both routes.
        int_columns = input_df.select_dtypes(include=['int64']).columns
        input_df[int_columns] = input_df[int_columns].astype('float64')
        validated_df = validate_and_prepare_input(input_df, model)
        prediction = model.predict(validated_df)[0]
        return jsonify({
            "status": "success",
            "prediction": int(prediction),
        })
    except InputValidationError as e:
        return jsonify({
            "status": "error",
            "error_type": "validation_error",
            "message": str(e),
        }), 400
    except Exception:
        # Keep the traceback for operators; return only a generic message so
        # internals are not leaked to the caller.
        pred_mainteanance_api.logger.exception("Single-engine prediction failed")
        return jsonify({
            "status": "error",
            "error_type": "internal_error",
            "message": "Unexpected server error",
        }), 500
# Define an endpoint that predicts maintenance needs for a batch of engines (CSV upload)
@pred_mainteanance_api.post('/v1/EngPredMaintenanceForBatch')
def predict_need_maintenance_for_batch():
    """Predict the maintenance need of every row in an uploaded CSV file.

    Expects a multipart/form-data request with the CSV under the form field ``file``.
    Spaces in column headers are replaced with underscores, and the
    ``Engine_Condition`` column (the label) is dropped if present. int64 columns
    are cast to float64, and the frame is then checked by
    ``validate_and_prepare_input`` before being passed to the model.

    Returns:
        200 with ``{"status": "success", "total_records": N, "predictions": [...]}``.
        400 with ``error_type`` ``"input_error"`` when no file is sent, no file is
        selected, the file has no data rows, or it cannot be parsed as CSV.
        400 with ``error_type`` ``"validation_error"`` when
        ``validate_and_prepare_input`` raises ``InputValidationError``.
        500 with ``error_type`` ``"internal_error"`` for any other failure (logged
        server-side, not exposed to the client).
    """
    try:
        # Get the uploaded CSV file from the request
        file = request.files.get('file')
        if file is None:
            return jsonify({
                "status": "error",
                "error_type": "input_error",
                "message": "File not provided",
            }), 400
        if file.filename == "":
            return jsonify({
                "status": "error",
                "error_type": "input_error",
                "message": "No file selected",
            }), 400
        try:
            input_df = pd.read_csv(file)
        except pd.errors.EmptyDataError:
            # A zero-byte upload has no header to parse, so pandas raises instead of
            # returning an empty frame; treat it like a header-only file below.
            input_df = pd.DataFrame()
        except (pd.errors.ParserError, UnicodeDecodeError):
            # Malformed or non-text content is a client error, not a server fault.
            return jsonify({
                "status": "error",
                "error_type": "input_error",
                "message": "Uploaded file is not a valid CSV",
            }), 400
        if input_df.empty:
            return jsonify({
                "status": "error",
                "error_type": "input_error",
                "message": "Uploaded file is empty",
            }), 400
        # Normalise headers first (spaces -> underscores) so the label column is
        # dropped whether the CSV spells it "Engine Condition" or "Engine_Condition".
        input_df.columns = input_df.columns.str.replace(' ', '_')
        input_df = input_df.drop(columns=['Engine_Condition'], errors='ignore')
        # Convert int -> float before validation
        int_columns = input_df.select_dtypes(include=['int64']).columns
        input_df[int_columns] = input_df[int_columns].astype('float64')
        # Validate the entire batch, then predict every row at once
        validated_df = validate_and_prepare_input(input_df, model)
        predictions = model.predict(validated_df)
        # Convert numpy array -> plain Python list so jsonify can serialise it
        prediction_list = predictions.tolist()
        return jsonify({
            "status": "success",  # overall batch status
            "total_records": len(prediction_list),
            "predictions": prediction_list,
        })
    except InputValidationError as e:
        return jsonify({
            "status": "error",
            "error_type": "validation_error",
            "message": str(e),
        }), 400
    except Exception:
        # Keep the traceback for operators; return only a generic message so
        # internals are not leaked to the caller.
        pred_mainteanance_api.logger.exception("Batch prediction failed")
        return jsonify({
            "status": "error",
            "error_type": "internal_error",
            "message": "Unexpected server error",
        }), 500
# Entry point when this file is executed directly: start Flask's built-in server.
if __name__ == "__main__":
    import os

    # The listening port comes from the PORT environment variable (7860 if unset).
    # Binding to 0.0.0.0 accepts connections on every network interface.
    listen_port = int(os.getenv("PORT", 7860))
    pred_mainteanance_api.run(host="0.0.0.0", port=listen_port)
|