kmsmohamedansar commited on
Commit
c696096
·
verified ·
1 Parent(s): 63e71f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -1,38 +1,41 @@
 
1
  import gradio as gr
2
  import pickle
3
  import numpy as np
4
  from sklearn.preprocessing import StandardScaler
5
 
6
- # 1) Load your model once at startup
7
- with open("rf_model.pkl", "rb") as f:
8
- MODEL = pickle.load(f)
 
9
 
10
- # 2) A simple predict function
11
- def predict(f1, f2, f3, f4, f5):
12
- X = np.array([[f1, f2, f3, f4, f5]])
13
- X_scaled = StandardScaler().fit_transform(X)
14
- return int(MODEL.predict(X_scaled)[0])
 
 
15
 
16
- # 3) Build your Gradio interface
17
  demo = gr.Interface(
18
  fn=predict,
19
  inputs=[
20
- gr.Slider(0, 10, step=1, value=5, label="Feature 1"),
21
- gr.Slider(0, 10, step=1, value=3, label="Feature 2"),
22
- gr.Slider(0, 10, step=1, value=7, label="Feature 3"),
23
- gr.Slider(0, 10, step=1, value=6, label="Feature 4"),
24
- gr.Slider(0, 10, step=1, value=4, label="Feature 5"),
25
  ],
26
- outputs=gr.Label(label="Prediction"),
27
  title="TaskMaster Job Scheduler",
28
- description="Enter five feature values to get a RandomForest prediction."
29
  )
30
 
 
31
  if __name__ == "__main__":
32
- # This tells Gradio to bind to all interfaces, on port 7860,
33
- # and to BLOCK the Python thread (so the container stays up).
34
  demo.launch(
35
  server_name="0.0.0.0",
36
  server_port=7860,
37
- prevent_thread_lock=False
38
  )
 
1
+ # app.py
2
  import gradio as gr
3
  import pickle
4
  import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
 
7
+ # 1) Load your trained model (make sure rf_model.pkl is in the repo root)
8
+ def load_model():
9
+ with open("rf_model.pkl", "rb") as f:
10
+ return pickle.load(f)
11
 
12
+ # 2) Prediction function: takes five numeric inputs, scales them, returns class
13
+ def predict(feature1, feature2, feature3, feature4, feature5):
14
+ model = load_model()
15
+ x = np.array([[feature1, feature2, feature3, feature4, feature5]])
16
+ # NOTE: We fit_transform here for demo; in prod you'd persist the scaler too.
17
+ x_scaled = StandardScaler().fit_transform(x)
18
+ return str(model.predict(x_scaled)[0])
19
 
20
+ # 3) Build the Gradio interface
21
  demo = gr.Interface(
22
  fn=predict,
23
  inputs=[
24
+ gr.Slider(0, 10, value=5, label="Feature 1"),
25
+ gr.Slider(0, 10, value=3, label="Feature 2"),
26
+ gr.Slider(0, 10, value=7, label="Feature 3"),
27
+ gr.Slider(0, 10, value=6, label="Feature 4"),
28
+ gr.Slider(0, 10, value=4, label="Feature 5"),
29
  ],
30
+ outputs=gr.Textbox(label="Predicted Class"),
31
  title="TaskMaster Job Scheduler",
32
+ description="Enter five feature values to get a RandomForest prediction.",
33
  )
34
 
35
+ # 4) Launch with SSR turned off for Spaces
36
  if __name__ == "__main__":
 
 
37
  demo.launch(
38
  server_name="0.0.0.0",
39
  server_port=7860,
40
+ ssr_mode=False # disable server-side rendering on the Space
41
  )