tofighi commited on
Commit
dfbb01b
·
verified ·
1 Parent(s): 88bc905

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import joblib
4
+ import pickle
5
+ import os
6
+
7
+ # Expected filename (place your trained model here)
8
+ MODEL_PATH = "cancer_forest.pkl"
9
+
10
+ # Try to load the model robustly (joblib first, then pickle)
11
+ model = None
12
+ _load_error = None
13
+ if os.path.exists(MODEL_PATH):
14
+ try:
15
+ model = joblib.load(MODEL_PATH)
16
+ except Exception as e1:
17
+ try:
18
+ with open(MODEL_PATH, "rb") as f:
19
+ model = pickle.load(f)
20
+ except Exception as e2:
21
+ _load_error = f"Failed to load model with joblib ({e1}) and pickle ({e2})"
22
+ else:
23
+ _load_error = f"Model file not found at '{MODEL_PATH}'. Please upload the trained model."
24
+
25
+ # Define target names in the same order as sklearn's breast cancer dataset:
26
+ # 0 -> malignant, 1 -> benign
27
+ TARGET_NAMES = ["malignant", "benign"]
28
+
29
+ def predict_breast(mean_concave_points: float, worst_concave_points: float, worst_area: float):
30
+ """
31
+ Predict breast cancer (malignant/benign) using the trained RandomForest model
32
+ that expects the top 3 features in this order:
33
+ 1. mean concave points
34
+ 2. worst concave points
35
+ 3. worst area
36
+
37
+ Returns:
38
+ - predicted label (string)
39
+ - dict of probabilities {label: probability}
40
+ """
41
+ if _load_error is not None:
42
+ return "MODEL LOAD ERROR", {"error": _load_error}
43
+
44
+ # Ensure model exists and supports predict/predict_proba
45
+ if model is None:
46
+ return "MODEL NOT LOADED", {"error": "Model is None after attempted load."}
47
+
48
+ arr = np.array([[mean_concave_points, worst_concave_points, worst_area]], dtype=float)
49
+
50
+ try:
51
+ pred_idx = int(model.predict(arr)[0])
52
+ except Exception as e:
53
+ return "PREDICTION ERROR", {"error": f"model.predict failed: {e}"}
54
+
55
+ proba = None
56
+ try:
57
+ proba_arr = model.predict_proba(arr)[0]
58
+ # Some classifiers put classes_ in different order; map by model.classes_
59
+ if hasattr(model, "classes_"):
60
+ # build mapping from class label (0/1) to probability
61
+ class_prob_map = {int(cls): float(proba_arr[i]) for i, cls in enumerate(model.classes_)}
62
+ proba = {
63
+ TARGET_NAMES[0]: float(class_prob_map.get(0, 0.0)),
64
+ TARGET_NAMES[1]: float(class_prob_map.get(1, 0.0)),
65
+ }
66
+ else:
67
+ # fallback: assume order is [0,1]
68
+ proba = {
69
+ TARGET_NAMES[0]: float(proba_arr[0]),
70
+ TARGET_NAMES[1]: float(proba_arr[1]) if len(proba_arr) > 1 else 0.0,
71
+ }
72
+ except Exception:
73
+ # If predict_proba not available, return deterministic prediction with probability 1.0
74
+ proba = {TARGET_NAMES[i]: (1.0 if i == pred_idx else 0.0) for i in range(len(TARGET_NAMES))}
75
+
76
+ predicted_label = TARGET_NAMES[pred_idx] if 0 <= pred_idx < len(TARGET_NAMES) else str(pred_idx)
77
+ return predicted_label, proba
78
+
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("# 🩺 Breast Cancer Detector — Random Forest (Top 3 Features)")
81
+ gr.Markdown("This app predicts whether a tumor is malignant or benign using a RandomForest model trained on the top 3 features from the sklearn breast cancer dataset.\n\n"
82
+ "**Expected input order**: mean concave points, worst concave points, worst area")
83
+
84
+ with gr.Row():
85
+ with gr.Column():
86
+ mean_concave_points = gr.Number(label="Mean Concave Points", value=0.0)
87
+ worst_concave_points = gr.Number(label="Worst Concave Points", value=0.0)
88
+ worst_area = gr.Number(label="Worst Area", value=0.0)
89
+
90
+ predict_btn = gr.Button("Predict")
91
+ output_class = gr.Label(label="Predicted Class")
92
+ output_proba = gr.JSON(label="Probabilities")
93
+
94
+ predict_btn.click(
95
+ fn=predict_breast,
96
+ inputs=[mean_concave_points, worst_concave_points, worst_area],
97
+ outputs=[output_class, output_proba]
98
+ )
99
+
100
+ with gr.Column():
101
+ gr.Markdown(
102
+ """
103
+ ## 📖 API Usage (example for Hugging Face Spaces)
104
+ When deployed to a Hugging Face Space, the Gradio app provides a POST /predict endpoint.
105
+
106
+ ### **API Endpoint (example)**