samarth-kamble commited on
Commit
1a5e0e7
·
verified ·
1 Parent(s): 0d9808a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -19
app.py CHANGED
@@ -1,22 +1,14 @@
1
  import gradio as gr
2
- import joblib
3
 
4
- # Load models using joblib
5
- knn_model = joblib.load("knn_model.joblib")
6
- rf_model = joblib.load("rf_model.joblib")
7
- svm_model = joblib.load("svm_model.joblib")
8
-
9
- model_map = {
10
- "KNN": knn_model,
11
- "Random Forest": rf_model,
12
- "SVM": svm_model
13
- }
14
 
15
  # Prediction function
16
- def predict(model_name, sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
17
  try:
18
- model = model_map[model_name]
19
-
20
  sex = int(sex)
21
  pregnant = int(pregnant)
22
  on_thyroxine = int(on_thyroxine)
@@ -26,9 +18,9 @@ def predict(model_name, sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
26
  FTI = float(FTI)
27
  TSH = float(TSH)
28
 
29
- prediction = model.predict([[sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH]])
30
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
31
- return f"Prediction using {model_name}: {label_map.get(prediction[0], 'Unknown')}"
32
  except Exception as e:
33
  return f"Error: {str(e)}"
34
 
@@ -36,7 +28,6 @@ def predict(model_name, sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
36
  demo = gr.Interface(
37
  fn=predict,
38
  inputs=[
39
- gr.Dropdown(["SVM", "KNN", "Random Forest"], label="Select Model"),
40
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
41
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
42
  gr.Radio([0, 1], label="On Thyroxine (0: No, 1: Yes)"),
@@ -47,8 +38,8 @@ demo = gr.Interface(
47
  gr.Number(label="TSH"),
48
  ],
49
  outputs="text",
50
- title="Hyperthyroid Prediction (Multi-Model)",
51
- description="Select a model and input medical features to predict thyroid condition."
52
  )
53
 
54
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import pickle
3
 
4
+ # Load the SVM model
5
+ with open("svm_model.pkl", "rb") as f:
6
+ svm_model = pickle.load(f)
 
 
 
 
 
 
 
7
 
8
  # Prediction function
9
+ def predict(sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
10
  try:
11
+ # Ensure inputs are correctly typed
 
12
  sex = int(sex)
13
  pregnant = int(pregnant)
14
  on_thyroxine = int(on_thyroxine)
 
18
  FTI = float(FTI)
19
  TSH = float(TSH)
20
 
21
+ prediction = svm_model.predict([[sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH]])
22
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
23
+ return f"Prediction: {label_map.get(prediction[0], 'Unknown')}"
24
  except Exception as e:
25
  return f"Error: {str(e)}"
26
 
 
28
  demo = gr.Interface(
29
  fn=predict,
30
  inputs=[
 
31
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
32
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
33
  gr.Radio([0, 1], label="On Thyroxine (0: No, 1: Yes)"),
 
38
  gr.Number(label="TSH"),
39
  ],
40
  outputs="text",
41
+ title="Hyperthyroid Prediction using SVM",
42
+ description="Enter 8 medical inputs to predict thyroid condition using an SVM model."
43
  )
44
 
45
  if __name__ == "__main__":