Sudharsanamr commited on
Commit
7eef64a
·
verified ·
1 Parent(s): 4053706

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +334 -39
src/streamlit_app.py CHANGED
@@ -1,40 +1,335 @@
1
- import altair as alt
2
- import numpy as np
 
 
 
 
 
 
 
 
 
 
3
  import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit_superkart_app.py
2
+ # Streamlit app for Super Kart — uses a remote Gradio backend and/or a local model file
3
+ # This file is written to work even when `streamlit` is NOT available in the environment.
4
+ # If `streamlit` is installed, the interactive web UI will run as intended.
5
+ # If `streamlit` is missing, the script falls back to a CLI/test mode so you can still
6
+ # validate remote endpoint behavior and quick local model tests.
7
+
8
+ import sys
9
+ import io
10
+ import traceback
11
+ import json
12
+ import requests
13
  import pandas as pd
14
+ import numpy as np
15
+ import joblib
16
+
17
+ # ----------------------
18
+ # Configuration
19
+ # ----------------------
20
+ DEFAULT_REMOTE = "https://sudharsanamr-superkart.hf.space/gradio_api/call/predict"
21
+ DEFAULT_FN_INDEX = 0
22
+
23
+ # ----------------------
24
+ # Utility functions
25
+ # ----------------------
26
+
27
def predict_remote(df, endpoint=DEFAULT_REMOTE, fn_index=DEFAULT_FN_INDEX, timeout=30):
    """Send each row of *df* to a Gradio-style endpoint, one POST per row.

    Each row is serialized with ``row.tolist()`` and sent as the payload
    ``{"data": [[...]], "fn_index": fn_index}``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input samples; one request is issued per row.
    endpoint : str
        Full URL of the Gradio ``/call/predict``-style API.
    fn_index : int
        Gradio function index to invoke (coerced with ``int()``).
    timeout : int | float
        Per-request timeout in seconds.

    Returns
    -------
    tuple[list, list]
        ``(results, errors)``. ``results`` holds one parsed response per
        successful row: the ``data`` field of the JSON body when present,
        otherwise the whole JSON value, or the raw text when the body is not
        JSON. ``errors`` is a list of ``(row_index, error_message)`` pairs —
        now a consistent 2-tuple shape (previously HTTP failures appended a
        3-tuple while transport failures appended a 2-tuple).
    """
    results = []
    errors = []
    for i, row in df.iterrows():
        payload = {"data": [row.tolist()], "fn_index": int(fn_index)}
        try:
            r = requests.post(endpoint, json=payload, timeout=timeout)
        except requests.RequestException as e:
            # Network-level failure (DNS, connection refused, timeout, ...).
            errors.append((i, str(e)))
            continue
        if r.status_code != 200:
            # Keep only the first KB of the body for diagnostics.
            errors.append((i, f"HTTP {r.status_code}: {r.text[:1000]}"))
            continue
        try:
            j = r.json()
        except ValueError:
            # Body is not JSON — fall back to the raw response text.
            results.append(r.text)
            continue
        if isinstance(j, dict) and 'data' in j:
            results.append(j['data'])
        else:
            results.append(j)
    return results, errors
52
+
53
+
54
def predict_with_model(df, model):
    """Run predictions for *df* with *model*.

    The estimator's ``predict`` method is preferred; a bare callable is used
    as a fallback. Anything that is neither raises ``ValueError``.
    """
    # Guard-clause style: try each supported model shape in order.
    if hasattr(model, 'predict'):
        return model.predict(df)
    if callable(model):
        return model(df)
    raise ValueError("Provided model is not callable and has no .predict method")
64
+
65
+
66
# A tiny dummy model used for CLI tests when no model file is provided.
class DummyModel:
    """Deterministic stand-in estimator for CLI/test mode.

    ``predict`` returns the per-row sum of the numeric columns of *X* as a
    plain Python list; a frame with no numeric columns yields all zeros.
    """

    def predict(self, X):
        numeric_cols = X.select_dtypes(include=[np.number])
        if numeric_cols.shape[1]:
            return numeric_cols.sum(axis=1).tolist()
        # No numeric data at all — emit one zero per input row.
        return [0.0] * len(X)
75
+
76
+
77
+ # ----------------------
78
+ # Main: Streamlit UI (if available)
79
+ # ----------------------
80
# Import streamlit lazily: the rest of the module must stay importable from a
# plain interpreter (CLI/test mode below) when streamlit is not installed.
try:
    import streamlit as st  # type: ignore
    ST_AVAILABLE = True
except Exception:
    # Broad catch is deliberate: any import-time failure (missing package,
    # broken install) should route us to CLI mode rather than crash.
    ST_AVAILABLE = False
85
+
86
if ST_AVAILABLE:
    # ------------------------------------------------------------------
    # Interactive Streamlit UI. Flow: pick a prediction mode in the sidebar,
    # provide input (CSV upload or one manual sample), select feature
    # columns, then run either the local model or the remote Gradio API.
    # ------------------------------------------------------------------
    st.set_page_config(page_title="Super Kart — Prediction App", layout="wide")
    st.title("Super Kart — Prediction App")

    # Sidebar: choose mode
    mode = st.sidebar.selectbox("Prediction mode", ["Remote API (Gradio)", "Local model (.joblib)"])

    # Initialize model/endpoint in module scope so they are always defined,
    # whichever sidebar branch runs on this rerun.
    model = None
    endpoint = DEFAULT_REMOTE

    if mode == "Remote API (Gradio)":
        st.sidebar.write("Remote endpoint (editable)")
        endpoint = st.sidebar.text_input("Gradio API endpoint", value=DEFAULT_REMOTE)
        if st.sidebar.button("Test endpoint"):
            # Fire a minimal probe payload so the user can sanity-check the URL.
            st.sidebar.info("Testing endpoint...")
            try:
                probe = {"data": [[0]], "fn_index": 0}
                r = requests.post(endpoint, json=probe, timeout=10)
                st.sidebar.write(f"Status: {r.status_code}")
                try:
                    st.sidebar.write(r.json())
                except Exception:
                    # Non-JSON body — show the first KB of raw text instead.
                    st.sidebar.write(r.text[:1000])
            except Exception as e:
                st.sidebar.error(f"Endpoint test failed: {e}")

    else:
        st.sidebar.write("Upload a local scikit-learn model (.joblib)")
        uploaded_model = st.sidebar.file_uploader("Upload model (.joblib)", type=["joblib", "pkl"], key="model_uploader")
        if uploaded_model is not None:
            try:
                # NOTE(review): joblib.load deserializes arbitrary pickled
                # code — only load model files you trust.
                bytes_data = uploaded_model.read()
                model = joblib.load(io.BytesIO(bytes_data))
                st.sidebar.success("Model loaded — ready for predictions")
            except Exception as e:
                st.sidebar.error(f"Failed to load model: {e}")
                st.sidebar.text(traceback.format_exc())

    st.markdown("---")

    st.header("Upload input data")
    uploaded_file = st.file_uploader("Upload CSV (rows = samples). If empty, use manual input below.", type=["csv"])

    input_df = None
    if uploaded_file is not None:
        try:
            input_df = pd.read_csv(uploaded_file)
            st.write("Preview of uploaded data:")
            st.dataframe(input_df.head())
        except Exception as e:
            st.error(f"Failed to read CSV: {e}")

    st.markdown("### Or enter single sample manually")
    manual_input = None
    with st.form("manual_form"):
        # Columns are created for layout only; their handles are unused.
        col1, col2 = st.columns(2)
        sample_text = st.text_area("Paste a single sample as comma-separated values (no header), or JSON list. Example: 12,3.5,0,1", height=80)
        submit = st.form_submit_button("Use manual sample")
        if submit and sample_text.strip():
            s = sample_text.strip()
            try:
                if s.startswith("["):
                    # JSON-list form: parse via pandas for permissive typing.
                    vals = pd.read_json(io.StringIO(s), typ='series')
                    manual_input = pd.DataFrame([vals.tolist()])
                else:
                    # CSV form: best-effort numeric coercion per token,
                    # falling back to the raw string.
                    parts = [x.strip() for x in s.split(',') if x.strip()!='']
                    parsed = []
                    for p in parts:
                        try:
                            if '.' in p:
                                parsed.append(float(p))
                            else:
                                parsed.append(int(p))
                        except ValueError:
                            # Fix: was a bare `except:` — only conversion
                            # failures should fall back to the string.
                            parsed.append(p)
                    manual_input = pd.DataFrame([parsed])
                st.success("Manual sample parsed")
                st.write(manual_input)
            except Exception as e:
                st.error(f"Failed to parse manual sample: {e}")

    # Uploaded CSV takes precedence over the manual sample.
    if input_df is not None:
        df_to_predict = input_df
    elif manual_input is not None:
        df_to_predict = manual_input
    else:
        df_to_predict = None

    if df_to_predict is None:
        st.info("Provide an input CSV or a manual sample to get predictions.")
    else:
        st.markdown("---")
        st.header("Prepare & Predict")
        st.write("Columns detected:", list(df_to_predict.columns))

        st.write("Select feature columns to use for prediction (order matters):")
        cols = st.multiselect("Feature columns", options=list(df_to_predict.columns), default=list(df_to_predict.columns))

        if not cols:
            st.error("Select at least one column")
        else:
            X = df_to_predict[cols].copy()
            # Best-effort numeric coercion of object columns; columns that
            # cannot be converted are left as-is.
            for c in X.columns:
                if X[c].dtype == object:
                    try:
                        X[c] = pd.to_numeric(X[c])
                    except (ValueError, TypeError):
                        # Fix: was a bare `except:` — keep the column
                        # unchanged only when coercion genuinely fails.
                        pass

            st.write("Prepared features (first rows):")
            st.dataframe(X.head())

            if mode == "Local model (.joblib)":
                if model is None:
                    st.error("No local model loaded. Upload a .joblib model in the sidebar.")
                else:
                    if st.button("Run local predictions"):
                        try:
                            preds = predict_with_model(X, model)
                            st.success("Predictions complete")
                            out = pd.DataFrame({"prediction": preds})
                            st.dataframe(out)

                            csv = out.to_csv(index=False)
                            st.download_button("Download predictions CSV", data=csv, file_name="predictions.csv")
                        except Exception as e:
                            st.error(f"Local prediction failed: {e}")
                            st.text(traceback.format_exc())

            else:
                st.write("Remote API endpoint:", endpoint)
                fn_index = st.number_input("fn_index (Gradio function index)", value=0, min_value=0)
                if st.button("Send to remote API"):
                    with st.spinner("Sending requests..."):
                        results, errors = predict_remote(X, endpoint=endpoint, fn_index=fn_index)

                    if results:
                        st.success(f"Received {len(results)} responses")
                        try:
                            # Gradio wraps each prediction in a list; unwrap
                            # the first element when that shape holds.
                            flattened = [r[0] if isinstance(r, list) and len(r)>0 else r for r in results]
                            out_df = pd.DataFrame({"prediction": flattened})
                            st.dataframe(out_df)
                            st.download_button("Download predictions CSV", data=out_df.to_csv(index=False), file_name="remote_predictions.csv")
                        except Exception:
                            # Fall back to showing the raw responses.
                            st.write(results)

                    if errors:
                        st.error(f"{len(errors)} errors occurred — showing first 5")
                        for e in errors[:5]:
                            st.write(e)

    st.markdown("---")
    st.write("Notes:\n- Many Gradio spaces expect POST body like: {\\\"data\\\": [[...inputs...]], \\\"fn_index\\\": 0}. If your space uses a different format, adjust the payload.\n- If you will upload your model for local predictions, upload it in the sidebar as a joblib file.")

    # Requirements hint (properly closed triple-backticks)
    st.sidebar.markdown("**Requirements**\n```\nstreamlit\npandas\nscikit-learn\njoblib\nrequests\n```")
243
+
244
# ----------------------
# CLI / Test Mode (runs when streamlit is not available)
# ----------------------
else:
    def _print_banner():
        # Explain why the UI did not start and how to get it.
        print("Streamlit is not available in this environment. Running in CLI/test mode.")
        print("To run the interactive app, install streamlit and run: streamlit run streamlit_superkart_app.py")
        print("Default remote endpoint:", DEFAULT_REMOTE)
        print("")

    def _cli_demo():
        """Best-effort smoke test: remote endpoint, dummy model, optional local model.

        If a model path is passed as the first CLI argument, it is loaded
        with joblib and used for a prediction as well.
        """
        _print_banner()
        # Create a small test dataframe
        df = pd.DataFrame({
            'feature_a': [1.0, 2.5, 3.3],
            'feature_b': [0, 1, 0],
            'category': ['x', 'y', 'z']
        })
        print("Test input:")
        print(df)

        # Try remote predict (best-effort; network must be allowed in environment)
        print('\n--- Remote endpoint test ---')
        try:
            results, errors = predict_remote(df[['feature_a', 'feature_b']], endpoint=DEFAULT_REMOTE)
            print(f"Remote results (count={len(results)}):")
            for r in results:
                print(r)
            if errors:
                print(f"Remote errors (count={len(errors)}):")
                for e in errors:
                    print(e)
        except Exception as e:
            print("Remote test failed:", str(e))
            traceback.print_exc()

        # Try local dummy model predict
        print('\n--- Local dummy model test ---')
        dummy = DummyModel()
        try:
            preds = predict_with_model(df[['feature_a', 'feature_b']], dummy)
            print('Dummy model predictions:', preds)
        except Exception as e:
            print('Local dummy model failed:', e)
            traceback.print_exc()

        # If user provided a model filename as CLI arg, try loading it and predicting
        # NOTE(review): joblib.load deserializes arbitrary pickled code —
        # only pass trusted model files.
        if len(sys.argv) > 1:
            model_path = sys.argv[1]
            print(f"\n--- Loading local model from: {model_path}")
            try:
                m = joblib.load(model_path)
                p = predict_with_model(df[['feature_a', 'feature_b']], m)
                print('Predictions from provided model:', p)
            except Exception as e:
                print('Failed to load/predict with provided model:', e)
                traceback.print_exc()

    # Add simple tests (these serve as test cases requested)
    def _run_tests():
        """Built-in self-tests; failures are printed, not raised (except Test 2's assert)."""
        print('\n=== Running built-in tests ===')
        # Test 1: predict_remote should return lists (may be empty if network blocked)
        df = pd.DataFrame({'a':[1,2], 'b':[3,4]})
        try:
            results, errors = predict_remote(df, endpoint=DEFAULT_REMOTE)
            print('predict_remote returned:', len(results), 'results and', len(errors), 'errors')
        except Exception as e:
            print('predict_remote raised exception (this may be due to network restrictions):', e)

        # Test 2: predict_with_model with DummyModel
        dummy = DummyModel()
        out = predict_with_model(df, dummy)
        assert len(out) == len(df), 'DummyModel should return same length output as input rows'
        print('DummyModel test passed — output:', out)

        # Test 3: predict_with_model error case
        # (the outer try/except AssertionError is defensive; the inner try is
        # what actually checks the ValueError contract)
        try:
            class BadModel: pass
            bad = BadModel()
            try:
                predict_with_model(df, bad)
                print('ERROR: predict_with_model should have raised for BadModel')
            except ValueError:
                print('predict_with_model correctly raised ValueError for invalid model')
        except AssertionError as e:
            print('Test assertion failed:', e)

        print('All CLI tests completed.')

    if __name__ == '__main__':
        _cli_demo()
        _run_tests()