Spaces:
Runtime error
Runtime error
| import os | |
| import io | |
| import json | |
| from typing import Any, Dict, List | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import pandas as pd | |
| from loader import load_store_model | |
# Flask application object for this service; flask_cors is attached with its
# default settings (no options are passed here).
app = Flask(import_name=__name__)
CORS(app=app)
def health():
    """Report that the SuperKart backend is up.

    Returns:
        A ``(response, 200)`` tuple whose JSON body is
        ``{"status": "ok", "message": "SuperKart backend running"}``.

    NOTE(review): no route decorator is attached to this function in this file
    as shown; confirm it is registered (e.g. via ``app.add_url_rule``) elsewhere.
    """
    body = {"status": "ok", "message": "SuperKart backend running"}
    return jsonify(body), 200
def _predict_single(store_id: Any, features: Dict[str, Any]):
    """Score a single feature row with the model stored for ``store_id``.

    ``store_id`` is converted with ``str()`` before the model lookup.

    Returns:
        A ``(prediction, meta)`` tuple: the first model output converted to
        ``float`` and the metadata object returned by ``load_store_model``.
        Exceptions from the loader or the model are not caught here.

    NOTE(review): nothing in this file as shown calls this helper; predict()
    performs the equivalent steps inline.
    """
    lookup_key = str(store_id)
    model, meta = load_store_model(lookup_key)
    single_row = pd.DataFrame([features])
    first_output = model.predict(single_row)[0]
    return float(first_output), meta
# NOTE(review): the route decorator below is commented out and no other
# decorator in this file as shown registers predict(); confirm the endpoint is
# wired up somewhere, otherwise it is unreachable.
#@app.post('/predict')
def predict():
    """Predict sales for one or more feature rows using one store's model.

    Expects a POST JSON body of either shape::

        {
          "store_id": "101",
          "features": { ... single row ... }
        }

        {
          "store_id": "101",
          "features_list": [ {...}, {...} ]   # multiple rows for same store
        }

    If both ``features`` and ``features_list`` are present, ``features`` wins.

    Returns:
        A ``(json_response, status)`` tuple. On success (200) the body holds
        ``store_id``, ``n_rows``, ``predictions`` (list of floats) and the
        ``model`` / ``metrics`` / ``features_used`` entries read from the
        model metadata. Errors are returned as ``{"error": ...}``:
        400 for an unparsable/empty/non-object body, a missing ``store_id``,
        missing or malformed features, or a failed prediction; 404 when
        ``load_store_model`` raises ``FileNotFoundError``. Other exceptions
        from ``load_store_model`` propagate to Flask.
    """
    try:
        payload = request.get_json(force=True, silent=False)
    except Exception as e:
        return jsonify({"error": f"Invalid JSON: {e}"}), 400
    if not payload:
        return jsonify({"error": "Empty payload"}), 400
    # Valid JSON that is not an object (e.g. a list or a string) has no .get().
    if not isinstance(payload, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    raw_store_id = payload.get("store_id")
    # Treat an explicit null like a missing key instead of loading a "None" model.
    store_id = "" if raw_store_id is None else str(raw_store_id).strip()
    if not store_id:
        return jsonify({"error": "Missing 'store_id'"}), 400

    try:
        model, meta = load_store_model(store_id)
    except FileNotFoundError as e:
        return jsonify({"error": str(e)}), 404

    # Normalise both request shapes to a list of rows so there is one prediction path.
    if "features" in payload:
        rows = [payload["features"]]
    elif "features_list" in payload:
        rows = payload["features_list"]
        if not isinstance(rows, list) or len(rows) == 0:
            return jsonify({"error": "'features_list' must be a non-empty list"}), 400
    else:
        return jsonify({"error": "Provide either 'features' or 'features_list'"}), 400

    try:
        df = pd.DataFrame(rows)
        yhat = model.predict(df)
        predictions = [float(v) for v in yhat]
    except Exception as e:
        # Return a JSON error (like predict_batch does per row) instead of an
        # unhandled 500. Malformed features would fail here.
        # NOTE(review): a broken model artifact also lands here; confirm 400 is wanted.
        return jsonify({"error": f"Prediction failed: {type(e).__name__}: {e}"}), 400

    return jsonify({
        "store_id": store_id,
        "n_rows": len(df),
        "predictions": predictions,
        "model": meta.get("model"),
        "metrics": meta.get("metrics", {}),
        "features_used": meta.get("features", []),
    }), 200
def _normalize_store_id(value: Any):
    """Turn a raw CSV ``store_id`` cell into the key passed to ``load_store_model``.

    Returns ``None`` for blank/NaN cells. Integral floats are rendered without
    a ``.0`` suffix, because pandas parses an integer column containing any
    blank cell as float64, which would otherwise turn ``101`` into ``"101.0"``.
    Surrounding whitespace is stripped.
    """
    if pd.isna(value):
        return None
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    text = str(value).strip()
    return text or None


def predict_batch():
    """Predict sales for every row of an uploaded CSV, routing rows by store.

    Expects a multipart form with a CSV file in the ``file`` field. The CSV must
    include a ``store_id`` column; every other column is passed to that store's
    model as a feature.

    Returns:
        A ``(json_response, status)`` tuple. On 200 the body has ``rows`` (row
        count), ``errors`` (one ``{"row", "store_id", "error"}`` entry per row
        that could not be scored; ``store_id`` is null when the cell was blank)
        and ``csv`` (the full input plus a ``predicted_sales`` column, empty for
        failed rows). A missing upload, an unreadable CSV, or a missing
        ``store_id`` column returns ``{"error": ...}`` with status 400.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded with field name 'file'"}), 400
    f = request.files["file"]
    try:
        df = pd.read_csv(f)
    except Exception as e:
        return jsonify({"error": f"Failed to read CSV: {e}"}), 400
    if "store_id" not in df.columns:
        return jsonify({"error": "CSV must include 'store_id' column"}), 400

    # Slice rows out of one feature frame instead of using iterrows(): iterrows()
    # builds a Series per row, so a row mixing int and float columns is upcast to
    # float (store_id 101 became "101.0") and per-column dtypes are lost.
    features = df.drop(columns=["store_id"])
    preds: List[Any] = []
    errors: List[Dict[str, Any]] = []
    # Simple cache for successfully loaded models during this batch call.
    cache: Dict[str, Any] = {}
    for pos, (idx, raw_sid) in enumerate(df["store_id"].items()):
        sid = _normalize_store_id(raw_sid)
        if sid is None:
            preds.append(None)
            errors.append({"row": int(idx), "store_id": None, "error": "Missing 'store_id'"})
            continue
        try:
            if sid not in cache:
                cache[sid] = load_store_model(sid)
            model, _meta = cache[sid]
            # reset_index so the single row is labelled 0, as in a freshly built frame.
            row_frame = features.iloc[[pos]].reset_index(drop=True)
            yhat = model.predict(row_frame)[0]
            preds.append(float(yhat))
        except FileNotFoundError as e:
            preds.append(None)
            errors.append({"row": int(idx), "store_id": sid, "error": str(e)})
        except Exception as e:
            preds.append(None)
            errors.append({"row": int(idx), "store_id": sid, "error": f"{type(e).__name__}: {e}"})

    df_out = df.copy()
    df_out["predicted_sales"] = preds
    # The complete result CSV is returned inline as a string (nothing is truncated).
    return jsonify({
        "rows": len(df_out),
        "errors": errors,
        "csv": df_out.to_csv(index=False),
    }), 200
if __name__ == "__main__":
    # Development-server entry point; PORT defaults to 7860.
    port = int(os.environ.get("PORT", 7860))
    # The interactive Werkzeug debugger can run arbitrary code sent from the
    # browser, and this binds to all interfaces (0.0.0.0), so debug mode is now
    # opt-in via FLASK_DEBUG (set and not "0"/"false"/"no") instead of always on.
    debug_env = os.environ.get("FLASK_DEBUG", "")
    debug = bool(debug_env) and debug_env.lower() not in {"0", "false", "no"}
    app.run(host="0.0.0.0", port=port, debug=debug)