# [Hugging Face file-view header] SebastianAndreu's picture | Create app.py | 2abf59b verified
# ================================================================
# HW3: Driving a Stop Sign Image Classifier with Gradio
#
# Author: Your Name
# Course: 24679 - Designing and Deploying AI/ML Systems
# Dataset: Binary stop sign dataset (class_1 = stop sign, class_0 = not stop sign)
# Task: Image classification deployed via Hugging Face Space
#
# Acknowledgments:
# - Model trained by a classmate in Homework 2
# - Deployment scaffold and documentation supported with AI assistance (ChatGPT, OpenAI)
# - Reference: Class-provided notebook "image gradio.ipynb"
# ================================================================
import os # For reading environment variables
import shutil # For directory cleanup
import zipfile # For extracting model archives
import pathlib # For path manipulations
import tempfile # For creating temporary files/directories
import gradio # For interactive UI
import pandas # For tabular data handling
import PIL.Image # For image I/O
import huggingface_hub # For downloading model assets
import autogluon.multimodal # For loading AutoGluon image classifier
# -------------------------
# Hugging Face model setup
# -------------------------
# Hub repository hosting the zipped predictor (update with the teammate's repo).
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
# Name of the predictor archive inside that repository.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Optional access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local folder for downloaded assets, and the sub-folder the archive is unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")
def _prepare_predictor_dir() -> str:
    """Fetch the zipped AutoGluon predictor from Hugging Face and unpack it.

    The archive is downloaded into ``CACHE_DIR`` and extracted into
    ``EXTRACT_DIR``, which is wiped and recreated on every call.

    Returns:
        The predictor directory as a string: the sole top-level entry of the
        extracted archive when that entry is a directory, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    download_options = dict(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    archive_path = huggingface_hub.hf_hub_download(**download_options)

    # Always extract into an empty folder so files from a previous run cannot linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # Archives that wrap everything in one folder: descend into that folder.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
# Executed at import time: download/unpack the predictor archive, then load the
# predictor once into a module-level object that do_predict() uses.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
# -------------------------
# Class labels
# -------------------------
# Display label for each class index. The emoji are written as escape
# sequences so the labels cannot be corrupted by a wrong file encoding.
CLASS_LABELS = {
    0: "\U0001F6A6 Not a Stop Sign",  # U+1F6A6 vertical traffic light
    1: "\U0001F6D1 Stop Sign",  # U+1F6D1 stop sign
}


def _human_label(c):
    """Translate a raw class value into its display label.

    Args:
        c: Class identifier. Anything ``int()`` accepts (``1``, ``"0"``,
            ``1.0`` ...) is looked up by its integer value; other values are
            looked up as-is.

    Returns:
        The matching ``CLASS_LABELS`` entry, or ``str(c)`` when there is none.
    """
    try:
        key = int(c)
    except (TypeError, ValueError, OverflowError):
        # Not integer-like: fall back to a lookup by the raw value.
        key = c
    return CLASS_LABELS.get(key, str(c))
# -------------------------
# Prediction + preprocessing
# -------------------------
def do_predict(pil_img: "PIL.Image.Image"):
    """Classify an uploaded image and build the values shown in the UI.

    Args:
        pil_img: Image delivered by the Gradio input component, or ``None``
            when the input is empty.

    Returns:
        A 4-tuple ``(original, preview, status, probabilities)``: the input
        image unchanged, a 224x224 resized copy (display only), a status
        message, and a ``{label: probability}`` dict. When no image is given
        or prediction fails, both images are ``None``, the dict is empty and
        the status text explains why.
    """
    if pil_img is None:
        return None, None, "No image provided.", {}
    try:
        # The 224x224 copy is for visualization only; the predictor is given
        # the original image file below.
        preprocessed_img = pil_img.resize((224, 224))

        # The predictor is handed a file path, so write the image to a
        # temporary folder that is deleted as soon as inference returns
        # (a bare mkdtemp() would leave one folder behind per request).
        with tempfile.TemporaryDirectory() as tmpdir:
            img_path = pathlib.Path(tmpdir) / "input.png"
            pil_img.save(img_path)
            df = pandas.DataFrame({"image": [str(img_path)]})
            proba_df = PREDICTOR.predict_proba(df)

        row = proba_df.iloc[0]
        # Every known class starts at 0.0; each returned column is then mapped
        # through _human_label so that int (0/1) and string ("0"/"1") column
        # names both resolve instead of silently reporting zero.
        pretty_dict = {label: 0.0 for label in CLASS_LABELS.values()}
        for class_id, probability in row.items():
            pretty_dict[_human_label(class_id)] = float(probability)

        return pil_img, preprocessed_img, "Prediction complete", pretty_dict
    except Exception as e:
        # UI boundary: report the failure in the status box rather than crash.
        return None, None, f"Error: {e}", {}
# -------------------------
# Example images
# -------------------------
# Sample inputs offered in the UI: one row per example, each row holding the
# single image URL passed to the image input.
EXAMPLES = [
    [image_url]
    for image_url in (
        "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f9/STOP_sign.jpg/640px-STOP_sign.jpg",
        "https://upload.wikimedia.org/wikipedia/commons/1/19/Swiss_Frutiger_Traffic_Sign.jpg",
    )
]
# -------------------------
# Gradio interface
# -------------------------
# NOTE(review): the nesting below is reconstructed; confirm that the status
# box and probability label are meant to sit below (not inside) the output row.
with gradio.Blocks() as demo:
    # Escaped stop-sign emoji (U+1F6D1) keeps the title intact whatever the
    # file encoding.
    gradio.Markdown("# \U0001F6D1 Stop Sign Detector")
    # The 224x224 image is only a preview: do_predict() sends the original
    # upload to the predictor, so the text must not claim otherwise.
    gradio.Markdown(
        "Upload a road scene or traffic sign image, and this app will classify whether\n"
        "a stop sign is present. The interface shows both the original image and a\n"
        "224x224 resized preview (for display only; the classifier is given the\n"
        "original upload)."
    )

    # Input: a single RGB image from file upload or webcam.
    with gradio.Row():
        image_in = gradio.Image(
            type="pil",
            label="Upload or capture an image",
            sources=["upload", "webcam"],
            image_mode="RGB",
        )

    # Outputs: original and resized images side by side.
    with gradio.Row():
        orig_out = gradio.Image(type="pil", label="Original Image")
        preproc_out = gradio.Image(type="pil", label="Preprocessed Image (224x224)")

    status_out = gradio.Textbox(label="Status")
    proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")

    # Re-run the classifier whenever the input image changes.
    image_in.change(
        fn=do_predict,
        inputs=[image_in],
        outputs=[orig_out, preproc_out, status_out, proba_pretty],
    )

    gradio.Examples(
        examples=EXAMPLES,
        inputs=[image_in],
        label="Representative examples",
        examples_per_page=3,
        cache_examples=False,
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()