# NOTE: the following lines are Hugging Face Hub page residue (author,
# commit message, commit hash) accidentally pasted into the source; kept
# here as comments so the file remains valid Python:
#   tofighi's picture
#   Update app.py
#   509ecdc verified
import gradio as gr
import numpy as np
import joblib
import pickle
import os
# Path to the trained model (place your cancer_forest.pkl here)
MODEL_PATH = "cancer_forest.pkl"


def _load_model(path: str):
    """Load a pickled sklearn model from *path*.

    Tries joblib first (the usual serialization for sklearn estimators),
    then falls back to the plain pickle protocol.

    Returns:
        A ``(model, error)`` tuple: on success ``(model, None)``, on
        failure ``(None, error_message)``.
    """
    if not os.path.exists(path):
        return None, f"Model file not found at '{path}'. Please upload the trained model."
    try:
        return joblib.load(path), None
    except Exception as e_joblib:
        # joblib could not read it; maybe it was written with plain pickle.
        try:
            with open(path, "rb") as f:
                return pickle.load(f), None
        except Exception as e_pickle:
            return None, f"Failed to load model with joblib ({e_joblib}) and pickle ({e_pickle})"


# Exactly one of these is None after loading; predict_breast checks both.
model, _load_error = _load_model(MODEL_PATH)

# Map class indices to human-readable names. sklearn breast_cancer: 0=malignant, 1=benign
TARGET_NAMES = ["malignant", "benign"]
def predict_breast(mean_concave_points: float, worst_concave_points: float, worst_area: float):
    """Classify a tumor as malignant or benign from the top-3 features.

    The RandomForest model was trained on, in this exact order:
    ``mean concave points``, ``worst concave points``, ``worst area``.

    Returns:
        A ``(label, probabilities)`` tuple where *label* is a string and
        *probabilities* maps each class name to its probability — or an
        ``{"error": ...}`` dict when the model is missing or prediction fails.
    """
    # Surface load-time problems instead of crashing at predict time.
    if _load_error is not None:
        return "MODEL LOAD ERROR", {"error": _load_error}
    if model is None:
        return "MODEL NOT LOADED", {"error": "Model is None after attempted load."}

    # Feature order must mirror the training pipeline.
    features = np.array([[mean_concave_points, worst_concave_points, worst_area]], dtype=float)

    try:
        # sklearn often yields numpy.int64; coerce to int for indexing.
        predicted = int(model.predict(features)[0])
    except Exception as e:
        return "PREDICTION ERROR", {"error": f"model.predict failed: {e}"}

    # Probabilities are best-effort: fall back to a deterministic 1.0/0.0
    # split if predict_proba is unavailable or misbehaves.
    try:
        raw_probs = model.predict_proba(features)[0]
        classes = getattr(model, "classes_", None)
        if classes is not None:
            # Honor the estimator's own class ordering.
            by_class = {int(c): float(p) for c, p in zip(classes, raw_probs)}
            probabilities = {
                TARGET_NAMES[0]: float(by_class.get(0, 0.0)),
                TARGET_NAMES[1]: float(by_class.get(1, 0.0)),
            }
        else:
            # No classes_ attribute: assume the conventional [0, 1] order.
            probabilities = {
                TARGET_NAMES[0]: float(raw_probs[0]),
                TARGET_NAMES[1]: float(raw_probs[1]) if len(raw_probs) > 1 else 0.0,
            }
    except Exception:
        probabilities = {
            name: (1.0 if idx == predicted else 0.0)
            for idx, name in enumerate(TARGET_NAMES)
        }

    if 0 <= predicted < len(TARGET_NAMES):
        label = TARGET_NAMES[predicted]
    else:
        label = str(predicted)
    return label, probabilities
# Build Gradio interface.
# Fix: the emoji and em-dash literals were mojibake (UTF-8 bytes decoded
# with a single-byte codepage, e.g. "๐Ÿฉบ" for 🩺, "โ€”" for —,
# "๐Ÿ“–" for 📖); they are restored to the intended characters below.
with gr.Blocks() as demo:
    # Page header and short usage description.
    demo_title = "🩺 Breast Cancer Detector — Random Forest (Top 3 Features)"
    demo_sub = (
        "Predict whether a tumor is malignant or benign using a RandomForest model "
        "trained on the top 3 features from sklearn's breast cancer dataset.\n\n"
        "Expected input order: mean concave points, worst concave points, worst area."
    )
    gr.Markdown(f"# {demo_title}")
    gr.Markdown(demo_sub)

    with gr.Row():
        # Left column: feature inputs, trigger button, and outputs.
        with gr.Column():
            mean_concave_points = gr.Number(label="Mean Concave Points", value=0.0)
            worst_concave_points = gr.Number(label="Worst Concave Points", value=0.0)
            worst_area = gr.Number(label="Worst Area", value=0.0)
            predict_btn = gr.Button("Predict")
            output_class = gr.Label(label="Predicted Class")
            output_proba = gr.JSON(label="Probabilities")
            # Input order must match predict_breast's signature.
            predict_btn.click(
                fn=predict_breast,
                inputs=[mean_concave_points, worst_concave_points, worst_area],
                outputs=[output_class, output_proba],
            )
        # Right column: static API-usage documentation.
        with gr.Column():
            api_markdown = (
                "## 📖 API Usage (example for Hugging Face Spaces)\n\n"
                "When deployed to a Hugging Face Space, the Gradio app provides a POST /predict endpoint.\n\n"
                "### JSON Request Example\n\n"
                "```json\n"
                "{\n"
                '  "mean_concave_points": 0.1,\n'
                '  "worst_concave_points": 0.2,\n'
                '  "worst_area": 800.0\n'
                "}\n"
                "```\n\n"
                "### Python Example\n\n"
                "```python\n"
                "import requests\n"
                "url = \"https://your-hf-space-name.hf.space/predict\"\n"
                "data = {\n"
                "  \"mean_concave_points\": 0.1,\n"
                "  \"worst_concave_points\": 0.2,\n"
                "  \"worst_area\": 800.0\n"
                "}\n"
                "resp = requests.post(url, json=data)\n"
                "print(resp.json())\n"
                "```\n\n"
                "### cURL Example\n\n"
                "```bash\n"
                "curl -X POST \"https://your-hf-space-name.hf.space/predict\" \\\n"
                "  -H \"Content-Type: application/json\" \\\n"
                "  -d '{\"mean_concave_points\":0.1,\"worst_concave_points\":0.2,\"worst_area\":800.0}'\n"
                "```\n\n"
                "Ensure the trained model file `cancer_forest.pkl` (RandomForestClassifier trained on:\n"
                "`mean concave points`, `worst concave points`, `worst area` in that order) is uploaded\n"
                "to the same directory as this app."
            )
            gr.Markdown(api_markdown)
# Launch the app (when run as a script / HF Space entrypoint)
if __name__ == "__main__":
    # Starts Gradio's local web server; Spaces invokes this module directly.
    demo.launch()