File size: 5,656 Bytes
2abf59b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# ================================================================
# HW3: Driving a Stop Sign Image Classifier with Gradio
#
# Author: Your Name
# Course: 24679 - Designing and Deploying AI/ML Systems
# Dataset: Binary stop sign dataset (class_1 = stop sign, class_0 = not stop sign)
# Task: Image classification deployed via Hugging Face Space
#
# Acknowledgments:
# - Model trained by a classmate in Homework 2
# - Deployment scaffold and documentation supported with AI assistance (ChatGPT, OpenAI)
# - Reference: Class-provided notebook "image gradio.ipynb"
# ================================================================
import os # For reading environment variables
import shutil # For directory cleanup
import zipfile # For extracting model archives
import pathlib # For path manipulations
import tempfile # For creating temporary files/directories
import gradio # For interactive UI
import pandas # For tabular data handling
import PIL.Image # For image I/O
import huggingface_hub # For downloading model assets
import autogluon.multimodal # For loading AutoGluon image classifier
# -------------------------
# Hugging Face model setup
# -------------------------
# Repo on the HF Hub that hosts the zipped AutoGluon predictor.
MODEL_REPO_ID = "cassieli226/sign-identification-automl" # <- update with teammate's repo
# Name of the archive inside that repo containing the predictor directory.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Optional auth token (needed only for private repos); read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Local cache for the downloaded archive and its extracted contents.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"
def _prepare_predictor_dir() -> str:
    """Fetch the zipped AutoGluon predictor from the Hub and unpack it.

    Downloads ``ZIP_FILENAME`` from ``MODEL_REPO_ID`` into ``CACHE_DIR``,
    replaces any previous extraction under ``EXTRACT_DIR``, and returns the
    path of the directory holding the predictor artifacts.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # Pull the archive from the model repo (cached under CACHE_DIR).
    # NOTE(review): `local_dir_use_symlinks` is deprecated/ignored in recent
    # huggingface_hub releases — kept for compatibility with older versions.
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Start from a clean extraction directory on every call.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(local_zip, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # Some archives wrap everything in one top-level folder; if so, that
    # folder (not EXTRACT_DIR itself) is the predictor root.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Download/extract the model and load the predictor once at import time so
# the Space starts "warm" and requests don't pay the model-load cost.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
# -------------------------
# Class labels
# -------------------------
CLASS_LABELS = {
0: "π¦ Not a Stop Sign",
1: "π Stop Sign"
}
def _human_label(c):
try:
ci = int(c)
return CLASS_LABELS.get(ci, str(c))
except Exception:
return CLASS_LABELS.get(c, str(c))
# -------------------------
# Prediction + preprocessing
# -------------------------
def do_predict(pil_img: "PIL.Image.Image"):
    """Classify an uploaded image as stop sign / not stop sign.

    Parameters
    ----------
    pil_img : PIL.Image.Image or None
        Image supplied by the Gradio component (None when the input is cleared).

    Returns
    -------
    tuple
        (original image, 224x224 preview, status message, {label: probability}).
        On error the image slots are None and the probability dict is empty.
    """
    if pil_img is None:
        return None, None, "No image provided.", {}
    try:
        # AutoGluon's predictor consumes file paths, so write the image to a
        # temporary directory that is removed when prediction finishes.
        # (The original mkdtemp() leaked one directory per request.)
        with tempfile.TemporaryDirectory() as tmpdir:
            img_path = pathlib.Path(tmpdir) / "input.png"
            pil_img.save(img_path)

            # 224x224 copy is for display only; the model applies its own
            # internal preprocessing pipeline.
            preprocessed_img = pil_img.copy().resize((224, 224))

            df = pandas.DataFrame({"image": [str(img_path)]})
            # Predict while the temp file still exists.
            proba_df = PREDICTOR.predict_proba(df)

        # Friendly column names for the two classes.
        proba_df = proba_df.rename(columns={
            0: "π¦ Not a Stop Sign (0)",
            1: "π Stop Sign (1)"
        })
        row = proba_df.iloc[0]
        pretty_dict = {
            "π¦ Not a Stop Sign": float(row.get("π¦ Not a Stop Sign (0)", 0.0)),
            "π Stop Sign": float(row.get("π Stop Sign (1)", 0.0)),
        }
        return pil_img, preprocessed_img, "Prediction complete", pretty_dict
    except Exception as e:
        # Surface the failure in the UI rather than crashing the Space.
        return None, None, f"Error: {str(e)}", {}
# -------------------------
# Example images
# -------------------------
# Sample inputs for gradio.Examples: one positive (stop sign) and one
# negative (non-stop traffic sign) image URL.
EXAMPLES = [
    ["https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/STOP_sign.jpg/640px-STOP_sign.jpg"],
    ["https://upload.wikimedia.org/wikipedia/commons/1/19/Swiss_Frutiger_Traffic_Sign.jpg"]
]
# -------------------------
# Gradio interface
# -------------------------
# Build the Gradio UI at import time so Hugging Face Spaces can discover `demo`.
with gradio.Blocks() as demo:
    gradio.Markdown("# π Stop Sign Detector")
    gradio.Markdown("""
    Upload a road scene or traffic sign image, and this app will classify whether
    a stop sign is present. The interface shows both the original image and the
    preprocessed image (224x224) that the model actually sees.
    """)
    with gradio.Row():
        # Input: file upload or webcam capture, normalized to an RGB PIL image.
        image_in = gradio.Image(
            type="pil",
            label="Upload or capture an image",
            sources=["upload", "webcam"],
            image_mode="RGB"
        )
    with gradio.Row():
        # Side-by-side: the raw upload and the 224x224 preview from do_predict.
        orig_out = gradio.Image(type="pil", label="Original Image")
        preproc_out = gradio.Image(type="pil", label="Preprocessed Image (224x224)")
    status_out = gradio.Textbox(label="Status")
    proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")
    # Predict on every input change — including clears, which do_predict
    # handles via its None fast-path.
    image_in.change(
        fn=do_predict,
        inputs=[image_in],
        outputs=[orig_out, preproc_out, status_out, proba_pretty]
    )
    gradio.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        label="Representative examples",
        examples_per_page=3,
        cache_examples=False,
    )
# Launch a local server only when run as a script (Spaces imports `demo`).
if __name__ == "__main__":
    demo.launch()