# superkart-api / app.py
# Origin: Hugging Face Space upload by itsjarvis (commit a9c1f61, verified)
from flask import Flask, request, jsonify
import joblib
import numpy as np
import pandas as pd
from flask_cors import CORS
import logging
from datetime import datetime
import os
import traceback
# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes so browser front-ends can call the API

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables for model and preprocessor; populated by load_model()
model = None            # fitted estimator extracted from the artifact bundle
preprocessor = None     # fitted preprocessing pipeline extracted from the bundle
model_artifacts = None  # full artifact dict (model, preprocessor, metadata)
def load_model():
    """Load the trained model and preprocessing artifacts from disk.

    Populates the module-level ``model``, ``preprocessor`` and
    ``model_artifacts`` globals from the joblib bundle.

    Returns:
        bool: True when the artifacts were loaded successfully, False otherwise.
    """
    global model, preprocessor, model_artifacts
    try:
        model_path = 'superkart_sales_forecasting_model.joblib'
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            return False

        # The bundle holds the estimator, its preprocessing pipeline and metadata.
        model_artifacts = joblib.load(model_path)
        model = model_artifacts['model']
        preprocessor = model_artifacts['preprocessor']

        # Use .get() for purely informational metadata keys: a missing
        # 'model_name'/'training_date' previously raised KeyError here, which
        # the broad except swallowed and wrongly reported a failed load even
        # though the model itself was already loaded.
        logger.info(f"Model loaded successfully: {model_artifacts.get('model_name', 'unknown')}")
        logger.info(f"Training date: {model_artifacts.get('training_date', 'unknown')}")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False
def validate_input_data(data):
    """Validate a prediction request payload.

    Checks that all required fields are present, that numeric fields are real
    numbers in sensible ranges, and that categorical fields take one of their
    known values.

    Args:
        data: dict-like request payload.

    Returns:
        tuple[bool, str]: (True, "Validation passed") on success, otherwise
        (False, <human-readable reason>).
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Store_Age'
    ]

    # Check if all required fields are present
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"Missing required fields: {missing_fields}"

    def _is_number(value):
        # bool is a subclass of int, so a bare isinstance(value, (int, float))
        # would accept True/False as numeric input — exclude it explicitly.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    # Validate data types and ranges
    try:
        # Numerical validations
        if not _is_number(data['Product_Weight']) or data['Product_Weight'] <= 0:
            return False, "Product_Weight must be a positive number"
        if not _is_number(data['Product_Allocated_Area']) or not (0 <= data['Product_Allocated_Area'] <= 1):
            return False, "Product_Allocated_Area must be between 0 and 1"
        if not _is_number(data['Product_MRP']) or data['Product_MRP'] <= 0:
            return False, "Product_MRP must be a positive number"
        if not _is_number(data['Store_Age']) or data['Store_Age'] < 0:
            return False, "Store_Age must be a non-negative number"

        # Categorical validations
        valid_sugar_content = ['Low Sugar', 'Regular', 'No Sugar']
        if data['Product_Sugar_Content'] not in valid_sugar_content:
            return False, f"Product_Sugar_Content must be one of: {valid_sugar_content}"
        valid_store_sizes = ['Small', 'Medium', 'High']
        if data['Store_Size'] not in valid_store_sizes:
            return False, f"Store_Size must be one of: {valid_store_sizes}"
        valid_city_types = ['Tier 1', 'Tier 2', 'Tier 3']
        if data['Store_Location_City_Type'] not in valid_city_types:
            return False, f"Store_Location_City_Type must be one of: {valid_city_types}"
        valid_store_types = ['Departmental Store', 'Supermarket Type1', 'Supermarket Type2', 'Food Mart']
        if data['Store_Type'] not in valid_store_types:
            return False, f"Store_Type must be one of: {valid_store_types}"

        return True, "Validation passed"
    except Exception as e:
        return False, f"Validation error: {str(e)}"
def preprocess_for_prediction(data):
    """Engineer features and run the fitted preprocessing pipeline.

    Args:
        data: a single record (dict) or a list of records.

    Returns:
        tuple: (processed_array, None) on success, or (None, error_message)
        when anything goes wrong.
    """
    try:
        frame = pd.DataFrame([data]) if isinstance(data, dict) else pd.DataFrame(data)

        # Bucket a value against ordered upper bounds; the last label is the
        # catch-all. Boundaries must match those used at training time.
        def bucketize(value, bounds, labels):
            for bound, label in zip(bounds, labels):
                if value <= bound:
                    return label
            return labels[-1]

        mrp_bounds = (69.0, 136.0, 202.0)
        mrp_labels = ('Low', 'Medium_Low', 'Medium_High', 'High')
        weight_bounds = (8.773, 12.89, 16.95)
        weight_labels = ('Light', 'Medium_Light', 'Medium_Heavy', 'Heavy')
        age_bounds = (20, 30)
        age_labels = ('New', 'Established', 'Legacy')

        # Add engineered features expected by the preprocessing pipeline.
        frame['Product_MRP_Category'] = frame['Product_MRP'].apply(
            lambda v: bucketize(v, mrp_bounds, mrp_labels))
        frame['Product_Weight_Category'] = frame['Product_Weight'].apply(
            lambda v: bucketize(v, weight_bounds, weight_labels))
        frame['Store_Age_Category'] = frame['Store_Age'].apply(
            lambda v: bucketize(v, age_bounds, age_labels))
        frame['City_Store_Type'] = frame['Store_Location_City_Type'] + '_' + frame['Store_Type']
        frame['Size_Type_Interaction'] = frame['Store_Size'] + '_' + frame['Store_Type']

        # Transform with the fitted pipeline loaded at startup.
        return preprocessor.transform(frame), None
    except Exception as e:
        return None, str(e)
@app.route('/', methods=['GET'])
def home():
    """Home endpoint with API information, endpoint listing and a sample payload."""
    # The outer truthiness check already guarantees the artifacts exist, so
    # the metadata can be read directly — the original repeated
    # `if model_artifacts else ...` inside each value, which was dead code.
    if model_artifacts:
        model_section = {
            "name": model_artifacts['model_name'],
            "training_date": model_artifacts['training_date'],
            "version": model_artifacts['model_version']
        }
    else:
        model_section = {"status": "Model not loaded"}

    api_info = {
        "message": "SuperKart Sales Forecasting API",
        "version": "1.0",
        "model_info": model_section,
        "endpoints": {
            "/": "API information",
            "/health": "Health check",
            "/predict": "Single prediction (POST)",
            "/batch_predict": "Batch predictions (POST)",
            "/model_info": "Model details"
        },
        "sample_input": {
            "Product_Weight": 10.5,
            "Product_Sugar_Content": "Low Sugar",
            "Product_Allocated_Area": 0.15,
            "Product_Type": "Fruits and Vegetables",
            "Product_MRP": 150.0,
            "Store_Size": "Medium",
            "Store_Location_City_Type": "Tier 2",
            "Store_Type": "Supermarket Type2",
            "Store_Age": 15
        }
    }
    return jsonify(api_info)
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint: reports whether the model is loaded."""
    loaded = model is not None
    return jsonify({
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "timestamp": datetime.now().isoformat(),
        "service": "SuperKart Sales Forecasting API"
    })
@app.route('/model_info', methods=['GET'])
def model_info():
    """Return detailed metadata about the loaded model, or a 500 if none is loaded."""
    if model_artifacts is None:
        return jsonify({"error": "Model not loaded"}), 500

    # Copy the metadata fields straight out of the artifact bundle, then add
    # the two derived fields.
    details = {
        key: model_artifacts[key]
        for key in ("model_name", "training_date", "model_version", "performance_metrics")
    }
    details["feature_count"] = len(model_artifacts['feature_names'])
    details["model_type"] = type(model).__name__
    return jsonify(details)
@app.route('/predict', methods=['POST'])
def predict():
    """Single prediction endpoint.

    Expects a JSON object with the fields checked by validate_input_data().
    Returns the forecast plus an echo of the input, or a JSON error with an
    appropriate HTTP status code (400 client error, 503 model unavailable,
    500 unexpected failure).
    """
    try:
        # Fail fast with a clear status instead of letting model.predict on
        # None surface as an opaque 500 "'NoneType' object has no attribute".
        if model is None:
            return jsonify({"error": "Model not loaded"}), 503

        # Get JSON data from request
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        # Validate input data
        is_valid, validation_message = validate_input_data(data)
        if not is_valid:
            return jsonify({"error": validation_message}), 400

        # Preprocess data
        processed_data, error = preprocess_for_prediction(data)
        if error:
            return jsonify({"error": f"Preprocessing failed: {error}"}), 400

        # Make prediction
        prediction = model.predict(processed_data)[0]

        # Prepare response
        response = {
            "prediction": float(prediction),
            "input_data": data,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        }
        logger.info(f"Prediction made: {prediction:.2f}")
        return jsonify(response)
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
@app.route('/batch_predict', methods=['POST'])
def batch_predict():
    """Batch prediction endpoint.

    Expects a non-empty JSON list of records. Each record is validated and
    predicted independently; failures are reported per-record in `errors`
    with None in the corresponding `predictions` slot.
    """
    try:
        # Fail fast with a clear status, consistent with /predict, instead of
        # raising per record on a None model.
        if model is None:
            return jsonify({"error": "Model not loaded"}), 503

        # Get JSON data from request
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        # Ensure data is a non-empty list
        if not isinstance(data, list):
            return jsonify({"error": "Data must be a list of records"}), 400
        if len(data) == 0:
            return jsonify({"error": "Empty data list provided"}), 400

        predictions = []
        errors = []
        for i, record in enumerate(data):
            try:
                # Validate input data
                is_valid, validation_message = validate_input_data(record)
                if not is_valid:
                    errors.append(f"Record {i}: {validation_message}")
                    predictions.append(None)
                    continue

                # Preprocess data
                processed_data, error = preprocess_for_prediction(record)
                if error:
                    errors.append(f"Record {i}: Preprocessing failed - {error}")
                    predictions.append(None)
                    continue

                # Make prediction
                prediction = model.predict(processed_data)[0]
                predictions.append(float(prediction))
            except Exception as e:
                errors.append(f"Record {i}: {str(e)}")
                predictions.append(None)

        # Prepare response
        response = {
            "predictions": predictions,
            "total_records": len(data),
            "successful_predictions": len([p for p in predictions if p is not None]),
            "errors": errors if errors else None,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        }
        logger.info(f"Batch prediction completed: {len(predictions)} records processed")
        return jsonify(response)
    except Exception as e:
        logger.error(f"Batch prediction error: {str(e)}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500
# Initialize the model when the app starts (Flask 3.x compatible)
def initialize():
    """Load the model once at startup and log the outcome."""
    logger.info("Initializing SuperKart Sales Forecasting API...")
    if load_model():
        logger.info("API initialization completed successfully")
    else:
        logger.error("API initialization failed - model could not be loaded")


# Run initialization as soon as the module is imported (e.g. under gunicorn),
# since Flask 3.x removed the before_first_request hook.
with app.app_context():
    initialize()
if __name__ == '__main__':
    # Running directly: load the model, then start the development server.
    if not load_model():
        print("[ERROR] Failed to load model. Please check model file.")
    else:
        print("[SUCCESS] Model loaded successfully")
        print("[STARTING] SuperKart Sales Forecasting API...")
        app.run(host='0.0.0.0', port=8080, debug=False)