mdsalmon159 commited on
Commit
0148d6b
·
verified ·
1 Parent(s): b24d216

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +43 -111
app.py CHANGED
@@ -1,27 +1,24 @@
1
- import os
2
- import logging
3
- import numpy as np
4
- import pandas as pd
5
- import joblib
6
  from flask import Flask, request, jsonify
7
 
8
- # try to enable CORS if available (optional)
9
  try:
10
  from flask_cors import CORS
11
- _CORS_AVAILABLE = True
12
  except Exception:
13
- _CORS_AVAILABLE = False
14
 
15
- # Create Flask app (WSGI callable must be named `app`)
16
  app = Flask(__name__)
17
- if _CORS_AVAILABLE:
18
  CORS(app)
19
 
20
- # logging (helps debugging in Spaces)
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
- # Load the trained model (guarded)
25
  MODEL_PATH = os.environ.get("MODEL_PATH", "superkart_prediction.joblib")
26
  model = None
27
  load_error = None
@@ -36,131 +33,66 @@ else:
36
  load_error = f"Model file not found at {MODEL_PATH}"
37
  logger.warning(load_error)
38
 
39
-
40
  @app.route("/", methods=["GET"])
41
  def home():
42
- return jsonify({"message": "Welcome to the SuperKart Sales Prediction API!"}), 200
43
-
44
 
45
  @app.route("/v1/sales", methods=["POST"])
46
- def predict_sales_single():
47
- """
48
- Expects a JSON body with the features required by the model.
49
- Example:
50
- {
51
- "Product_Id": 123,
52
- "Product_Weight": 1.23,
53
- "Product_Sugar_Content": "Low",
54
- "Product_Allocated_Area": 10,
55
- "Product_Type": "TypeA",
56
- "Product_MRP": 99.99,
57
- "Store_Id": "S1",
58
- "Store_Establishment_Year": 1998,
59
- "Store_Size": "Small",
60
- "Store_Location_City_Type": "Tier1",
61
- "Store_Type": "Supermarket",
62
- "log_output": true # optional, True if model predicts log(sales)
63
- }
64
- """
65
  if model is None:
66
- return jsonify({"error": "Model not loaded", "details": load_error}), 500
67
-
68
  try:
69
  data = request.get_json(force=True)
70
- if not data:
71
- return jsonify({"error": "Invalid or empty JSON body"}), 400
72
-
73
- # expected columns (adapt to your model's expected features/order)
74
  expected_cols = [
75
- "Product_Id", "Product_Weight", "Product_Sugar_Content", "Product_Allocated_Area",
76
- "Product_Type", "Product_MRP", "Store_Id", "Store_Establishment_Year",
77
- "Store_Size", "Store_Location_City_Type", "Store_Type"
78
  ]
79
-
80
- # build a row dict with missing keys set to None
81
- row = {col: data.get(col, None) for col in expected_cols}
82
-
83
- # convert to DataFrame and predict
84
- input_df = pd.DataFrame([row])
85
- pred = model.predict(input_df)
86
-
87
- # handle log-output option
88
- log_output = bool(data.get("log_output", False))
89
- sale = float(np.exp(pred[0])) if log_output else float(pred[0])
90
- sale = round(sale, 2)
91
-
92
- return jsonify({"predicted_sales": sale}), 200
93
-
94
  except Exception as e:
95
- logger.exception("Single prediction failed")
96
- return jsonify({"error": "Prediction failed", "details": str(e)}), 500
97
-
98
 
99
  @app.route("/v1/sales/batch", methods=["POST"])
100
- def predict_sales_batch():
101
- """
102
- Accepts:
103
- - multipart/form-data file upload with key 'file' (CSV), OR
104
- - JSON body with key 'data' (list of records) OR
105
- - JSON body as a list of records
106
- Returns JSON mapping of IDs (if present) or indices to predicted sales.
107
- """
108
  if model is None:
109
- return jsonify({"error": "Model not loaded", "details": load_error}), 500
110
-
111
  try:
112
- # 1) CSV upload
113
  if "file" in request.files:
114
- file = request.files["file"]
115
- df = pd.read_csv(file)
116
  else:
117
- # 2) JSON body
118
- json_body = request.get_json(force=True)
119
- if json_body is None:
120
- return jsonify({"error": "No file uploaded and no JSON body provided"}), 400
121
-
122
- if isinstance(json_body, dict) and "data" in json_body:
123
- df = pd.DataFrame(json_body["data"])
124
- elif isinstance(json_body, list):
125
- df = pd.DataFrame(json_body)
126
  else:
127
- return jsonify({"error": "JSON must be a list of records or contain 'data' key"}), 400
128
-
129
  if df.empty:
130
- return jsonify({"error": "Input data is empty"}), 400
131
-
132
- # Predict (model should accept DataFrame columns as provided)
133
  preds = model.predict(df).tolist()
134
-
135
- # Handle optional query param or JSON flag for log output (default False)
136
- log_output_flag = request.args.get("log_output", "false").lower() == "true"
137
- if log_output_flag:
138
- preds = [float(round(np.exp(p), 2)) for p in preds]
139
  else:
140
- preds = [float(round(p, 2)) for p in preds]
141
-
142
- # map back to ID column if present
143
- id_col = None
144
- for candidate in ("id", "ID", "Product_Id"):
145
- if candidate in df.columns:
146
- id_col = candidate
147
- break
148
-
149
  if id_col:
150
- ids = df[id_col].astype(str).tolist()
151
- out = dict(zip(ids, preds))
152
  else:
153
  out = {str(i): preds[i] for i in range(len(preds))}
154
-
155
  return jsonify({"predictions": out}), 200
156
-
157
  except Exception as e:
158
- logger.exception("Batch prediction failed")
159
- return jsonify({"error": "Batch prediction failed", "details": str(e)}), 500
160
-
161
 
162
- # local debug server; in Spaces/Gunicorn we expose 'app' as WSGI callable
163
  if __name__ == "__main__":
164
  port = int(os.environ.get("PORT", 7860))
165
- logger.info(f"Starting local Flask server on port {port}")
166
  app.run(host="0.0.0.0", port=port, debug=True)
 
 
1
+
2
+ %%bash
3
+ cat > app.py <<'PY'
4
+ import os, logging, joblib, numpy as np, pandas as pd
 
5
  from flask import Flask, request, jsonify
6
 
7
+ # Attempt optional CORS
8
  try:
9
  from flask_cors import CORS
10
+ _CORS = True
11
  except Exception:
12
+ _CORS = False
13
 
14
+ # WSGI callable expected by gunicorn must be named `app`
15
  app = Flask(__name__)
16
+ if _CORS:
17
  CORS(app)
18
 
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
22
  MODEL_PATH = os.environ.get("MODEL_PATH", "superkart_prediction.joblib")
23
  model = None
24
  load_error = None
 
33
  load_error = f"Model file not found at {MODEL_PATH}"
34
  logger.warning(load_error)
35
 
 
36
  @app.route("/", methods=["GET"])
37
  def home():
38
+ return jsonify({"message":"API up"}), 200
 
39
 
40
  @app.route("/v1/sales", methods=["POST"])
41
+ def predict_single():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  if model is None:
43
+ return jsonify({"error":"Model not loaded","details": load_error}), 500
 
44
  try:
45
  data = request.get_json(force=True)
 
 
 
 
46
  expected_cols = [
47
+ "Product_Id","Product_Weight","Product_Sugar_Content","Product_Allocated_Area",
48
+ "Product_Type","Product_MRP","Store_Id","Store_Establishment_Year",
49
+ "Store_Size","Store_Location_City_Type","Store_Type"
50
  ]
51
+ row = {c: data.get(c, None) for c in expected_cols}
52
+ df = pd.DataFrame([row])
53
+ pred = model.predict(df)
54
+ log_out = bool(data.get("log_output", False))
55
+ val = float(np.exp(pred[0])) if log_out else float(pred[0])
56
+ return jsonify({"predicted_sales": round(val,2)}), 200
 
 
 
 
 
 
 
 
 
57
  except Exception as e:
58
+ logger.exception("predict_single failed")
59
+ return jsonify({"error":"prediction failed","details": str(e)}), 500
 
60
 
61
  @app.route("/v1/sales/batch", methods=["POST"])
62
+ def predict_batch():
 
 
 
 
 
 
 
63
  if model is None:
64
+ return jsonify({"error":"Model not loaded","details": load_error}), 500
 
65
  try:
 
66
  if "file" in request.files:
67
+ df = pd.read_csv(request.files["file"])
 
68
  else:
69
+ jb = request.get_json(force=True)
70
+ if isinstance(jb, dict) and "data" in jb:
71
+ df = pd.DataFrame(jb["data"])
72
+ elif isinstance(jb, list):
73
+ df = pd.DataFrame(jb)
 
 
 
 
74
  else:
75
+ return jsonify({"error":"No file or JSON data provided"}), 400
 
76
  if df.empty:
77
+ return jsonify({"error":"Input empty"}), 400
 
 
78
  preds = model.predict(df).tolist()
79
+ log_flag = request.args.get("log_output","false").lower() == "true"
80
+ if log_flag:
81
+ preds = [round(float(np.exp(p)),2) for p in preds]
 
 
82
  else:
83
+ preds = [round(float(p),2) for p in preds]
84
+ id_col = next((c for c in ("id","ID","Product_Id") if c in df.columns), None)
 
 
 
 
 
 
 
85
  if id_col:
86
+ keys = df[id_col].astype(str).tolist()
87
+ out = dict(zip(keys, preds))
88
  else:
89
  out = {str(i): preds[i] for i in range(len(preds))}
 
90
  return jsonify({"predictions": out}), 200
 
91
  except Exception as e:
92
+ logger.exception("predict_batch failed")
93
+ return jsonify({"error":"batch prediction failed","details": str(e)}), 500
 
94
 
 
95
  if __name__ == "__main__":
96
  port = int(os.environ.get("PORT", 7860))
 
97
  app.run(host="0.0.0.0", port=port, debug=True)
98
+ PY