samarth-kamble commited on
Commit
7860eb0
·
verified ·
1 Parent(s): 592a331

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -8
app.py CHANGED
@@ -1,24 +1,41 @@
1
  import gradio as gr
2
  import pickle
3
 
4
- # Load model
 
 
 
 
 
 
5
  with open("svm_model.pkl", "rb") as f:
6
- model = pickle.load(f)
 
 
 
 
 
 
 
7
 
8
  # Prediction function
9
- def predict(sex, pregnant, TT4, T3, T4U, FTI, TSH):
10
  try:
11
- # Convert inputs to appropriate types
 
 
12
  sex = int(sex)
13
  pregnant = int(pregnant)
 
14
  TT4 = float(TT4)
15
  T3 = float(T3)
16
  T4U = float(T4U)
17
  FTI = float(FTI)
18
  TSH = float(TSH)
19
- prediction = model.predict([[sex, pregnant, TT4, T3, T4U, FTI, TSH]])
 
20
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
21
- return f"Prediction: {label_map.get(prediction[0], 'Unknown')}"
22
  except Exception as e:
23
  return f"Error: {str(e)}"
24
 
@@ -26,8 +43,10 @@ def predict(sex, pregnant, TT4, T3, T4U, FTI, TSH):
26
  demo = gr.Interface(
27
  fn=predict,
28
  inputs=[
 
29
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
30
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
 
31
  gr.Number(label="TT4"),
32
  gr.Number(label="T3"),
33
  gr.Number(label="T4U"),
@@ -35,8 +54,8 @@ demo = gr.Interface(
35
  gr.Number(label="TSH"),
36
  ],
37
  outputs="text",
38
- title="Hyperthyroid Prediction",
39
- description="Enter patient info to predict thyroid condition using SVM model."
40
  )
41
 
42
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import pickle
3
 
4
+ # Load all models
5
+ with open("knn_model.pkl", "rb") as f:
6
+ knn_model = pickle.load(f)
7
+
8
+ with open("rf_model.pkl", "rb") as f:
9
+ rf_model = pickle.load(f)
10
+
11
  with open("svm_model.pkl", "rb") as f:
12
+ svm_model = pickle.load(f)
13
+
14
+ # Map model name to actual model
15
+ model_map = {
16
+ "KNN": knn_model,
17
+ "Random Forest": rf_model,
18
+ "SVM": svm_model
19
+ }
20
 
21
  # Prediction function
22
+ def predict(model_name, sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
23
  try:
24
+ model = model_map[model_name]
25
+
26
+ # Ensure inputs are correctly typed
27
  sex = int(sex)
28
  pregnant = int(pregnant)
29
+ on_thyroxine = int(on_thyroxine)
30
  TT4 = float(TT4)
31
  T3 = float(T3)
32
  T4U = float(T4U)
33
  FTI = float(FTI)
34
  TSH = float(TSH)
35
+
36
+ prediction = model.predict([[sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH]])
37
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
38
+ return f"Prediction using {model_name}: {label_map.get(prediction[0], 'Unknown')}"
39
  except Exception as e:
40
  return f"Error: {str(e)}"
41
 
 
43
  demo = gr.Interface(
44
  fn=predict,
45
  inputs=[
46
+ gr.Dropdown(["SVM", "KNN", "Random Forest"], label="Select Model"),
47
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
48
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
49
+ gr.Radio([0, 1], label="On Thyroxine (0: No, 1: Yes)"),
50
  gr.Number(label="TT4"),
51
  gr.Number(label="T3"),
52
  gr.Number(label="T4U"),
 
54
  gr.Number(label="TSH"),
55
  ],
56
  outputs="text",
57
+ title="Hyperthyroid Prediction (Multi-Model)",
58
+ description="Select a model and input medical features to predict thyroid condition."
59
  )
60
 
61
  if __name__ == "__main__":