samarth-kamble commited on
Commit
b5160be
·
verified ·
1 Parent(s): 1a5e0e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -10
app.py CHANGED
@@ -1,36 +1,51 @@
1
  import gradio as gr
2
  import pickle
3
 
4
- # Load the SVM model
 
 
 
 
 
 
5
  with open("svm_model.pkl", "rb") as f:
6
  svm_model = pickle.load(f)
7
 
 
 
 
 
 
 
 
8
  # Prediction function
9
- def predict(sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH):
10
  try:
11
- # Ensure inputs are correctly typed
 
 
12
  sex = int(sex)
13
  pregnant = int(pregnant)
14
- on_thyroxine = int(on_thyroxine)
15
  TT4 = float(TT4)
16
  T3 = float(T3)
17
  T4U = float(T4U)
18
  FTI = float(FTI)
19
  TSH = float(TSH)
20
 
21
- prediction = svm_model.predict([[sex, pregnant, on_thyroxine, TT4, T3, T4U, FTI, TSH]])
 
22
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
23
- return f"Prediction: {label_map.get(prediction[0], 'Unknown')}"
24
  except Exception as e:
25
  return f"Error: {str(e)}"
26
 
27
  # Gradio UI
28
  demo = gr.Interface(
29
- fn=predict,
30
  inputs=[
 
31
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
32
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
33
- gr.Radio([0, 1], label="On Thyroxine (0: No, 1: Yes)"),
34
  gr.Number(label="TT4"),
35
  gr.Number(label="T3"),
36
  gr.Number(label="T4U"),
@@ -38,8 +53,8 @@ demo = gr.Interface(
38
  gr.Number(label="TSH"),
39
  ],
40
  outputs="text",
41
- title="Hyperthyroid Prediction using SVM",
42
- description="Enter 8 medical inputs to predict thyroid condition using an SVM model."
43
  )
44
 
45
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import pickle
3
 
4
+ # Load models using pickle
5
+ with open("knn_model.pkl", "rb") as f:
6
+ knn_model = pickle.load(f)
7
+
8
+ with open("rf_model.pkl", "rb") as f:
9
+ rf_model = pickle.load(f)
10
+
11
  with open("svm_model.pkl", "rb") as f:
12
  svm_model = pickle.load(f)
13
 
14
+ # Map for model selection
15
+ model_map = {
16
+ "KNN": knn_model,
17
+ "Random Forest": rf_model,
18
+ "SVM": svm_model
19
+ }
20
+
21
  # Prediction function
22
+ def dis_prediction(model_name, sex, pregnant, TT4, T3, T4U, FTI, TSH):
23
  try:
24
+ model = model_map[model_name]
25
+
26
+ # Convert input to correct types
27
  sex = int(sex)
28
  pregnant = int(pregnant)
 
29
  TT4 = float(TT4)
30
  T3 = float(T3)
31
  T4U = float(T4U)
32
  FTI = float(FTI)
33
  TSH = float(TSH)
34
 
35
+ # Predict
36
+ result = model.predict([[sex, pregnant, TT4, T3, T4U, FTI, TSH]])
37
  label_map = {0: "Hyperthyroid", 1: "Hypothyroid", 2: "Negative"}
38
+ return f"Prediction using {model_name}: {label_map.get(result[0], 'Unknown')}"
39
  except Exception as e:
40
  return f"Error: {str(e)}"
41
 
42
  # Gradio UI
43
  demo = gr.Interface(
44
+ fn=dis_prediction,
45
  inputs=[
46
+ gr.Dropdown(["SVM", "KNN", "Random Forest"], label="Select Model"),
47
  gr.Radio([0, 1], label="Sex (0: Female, 1: Male)"),
48
  gr.Radio([0, 1], label="Pregnant (0: No, 1: Yes)"),
 
49
  gr.Number(label="TT4"),
50
  gr.Number(label="T3"),
51
  gr.Number(label="T4U"),
 
53
  gr.Number(label="TSH"),
54
  ],
55
  outputs="text",
56
+ title="Hyperthyroid Prediction (with Pickle Models)",
57
+ description="Choose a model and enter patient data to predict thyroid condition."
58
  )
59
 
60
  if __name__ == "__main__":