# Extracted from a repository file view (commit b59fe76, 2,202 bytes).
from flask import Flask, request, jsonify
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import numpy as np
# Flask application object; the route decorators below register on it.
app = Flask(__name__)
# On-disk locations used by the service.
# CSV holding the training rows; retraining reads it and /update appends to it.
DATA_PATH = "real_world_garbage_data.csv"
# Serialized classifier and its matching feature scaler (joblib files).
MODEL_PATH = "garbage_model.pkl"
SCALER_PATH = "scaler.pkl"
def retrain_model():
    """Rebuild the classifier and scaler from the CSV at DATA_PATH.

    Reads the whole dataset, standardises the four feature columns, fits a
    random forest against the 'Status' column, writes both fitted objects to
    MODEL_PATH and SCALER_PATH, and hands them back to the caller.

    Returns:
        tuple: (fitted RandomForestClassifier, fitted StandardScaler).
    """
    frame = pd.read_csv(DATA_PATH)
    feature_columns = ['Hour', 'Weight_kg', 'Distance_cm', 'Is_Weekend']
    features = frame[feature_columns]
    labels = frame['Status']

    # The scaler is fitted on the same rows the forest is trained on.
    fitted_scaler = StandardScaler()
    scaled_features = fitted_scaler.fit_transform(features)

    forest = RandomForestClassifier(n_estimators=100, max_depth=5)
    forest.fit(scaled_features, labels)

    # Persist the refreshed artefacts: model first, then scaler.
    for artefact, destination in ((forest, MODEL_PATH), (fitted_scaler, SCALER_PATH)):
        joblib.dump(artefact, destination)
    return forest, fitted_scaler
# Load the persisted model and scaler on startup; if either cannot be loaded,
# fall back to training a fresh pair from the CSV.
# NOTE(review): joblib.load unpickles the file contents, so these paths must
# only ever point at files this service (or a trusted operator) wrote.
try:
    model = joblib.load(MODEL_PATH)
    scaler = joblib.load(SCALER_PATH)
except Exception as exc:
    # `except Exception` (not a bare `except:`) so Ctrl-C / SystemExit during
    # startup still propagate; any load failure (missing or unreadable file)
    # keeps the original best-effort retrain.
    print("Model load failed:", repr(exc))
    print("No model found. Training initial model...")
    model, scaler = retrain_model()
@app.route('/predict', methods=['POST'])
def predict():
    """Endpoint for real-time sensor predictions.

    Expects a JSON object with the numeric fields 'hour', 'weight',
    'distance' and 'weekend'. The values are scaled with the module-level
    scaler and classified with the module-level model.

    Returns:
        200 with {'status': <int prediction>, 'message': ...} on success, or
        400 with {'message': ...} when the body is not a JSON object, a field
        is missing, or a value is not a finite number.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so every malformed request gets the same 400 shape below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'message': 'Request body must be a JSON object'}), 400

    # Order matters: it must match the training columns
    # [Hour, Weight_kg, Distance_cm, Is_Weekend].
    required = ('hour', 'weight', 'distance', 'weekend')
    missing = [key for key in required if key not in data]
    if missing:
        return jsonify(
            {'message': 'Missing required fields: ' + ', '.join(missing)}
        ), 400

    try:
        row = [float(data[key]) for key in required]
    except (TypeError, ValueError):
        return jsonify({'message': 'All fields must be numeric'}), 400

    features = np.array([row])
    if not np.isfinite(features).all():
        return jsonify({'message': 'All fields must be finite numbers'}), 400

    # Scale and predict
    features_scaled = scaler.transform(features)
    prediction = int(model.predict(features_scaled)[0])
    return jsonify({'status': prediction, 'message': 'Prediction successful'})
@app.route('/update', methods=['POST'])
def update_and_retrain():
    """Endpoint to add new data and trigger retraining.

    Expects a non-empty JSON list of data points. Each data point is either
    an object keyed by the CSV's column names or a list of values in the
    CSV's column order. Rows are aligned to the header of the existing CSV
    before being appended, so key order in the request cannot shift values
    into the wrong column. Retraining runs synchronously in this request.

    Returns:
        200 with a success message once the model has been retrained, or
        400 with {'message': ...} when the payload is malformed.
    """
    global model, scaler
    new_data = request.get_json(silent=True)  # List of data points
    if not isinstance(new_data, list) or not new_data:
        return jsonify(
            {'message': 'Request body must be a non-empty JSON list of data points'}
        ), 400

    # The file is appended to with header=False, so the existing header is
    # the only source of truth for column names and order.
    expected = list(pd.read_csv(DATA_PATH, nrows=0).columns)

    if all(isinstance(row, dict) for row in new_data):
        missing = sorted(
            {col for row in new_data for col in expected if col not in row}
        )
        if missing:
            return jsonify(
                {'message': 'Missing required columns: ' + ', '.join(missing)}
            ), 400
    elif all(isinstance(row, list) for row in new_data):
        if any(len(row) != len(expected) for row in new_data):
            return jsonify(
                {'message': 'Each data point must have %d values' % len(expected)}
            ), 400
    else:
        return jsonify(
            {'message': 'Data points must be all objects or all lists'}
        ), 400

    # columns=expected selects/orders record keys (dropping extras) and
    # labels positional rows, so the appended lines match the file layout.
    df_new = pd.DataFrame(new_data, columns=expected)
    df_new.to_csv(DATA_PATH, mode='a', header=False, index=False)

    # Trigger the retraining loop
    model, scaler = retrain_model()
    return jsonify({'message': 'Model retrained successfully with new data'})
# Start Flask's built-in server on port 5000 when the file is run as a script.
if __name__ == '__main__':
    # NOTE(review): debug=True turns on Flask's debug mode (reloader and
    # interactive debugger); confirm this entry point is never used for a
    # publicly reachable deployment.
    app.run(port=5000, debug=True)