# Shashwat98's picture
# Update app.py
# 5f8adf2 verified
import gradio as gr
from typing import Any, Dict, List
from src.registry import get_model_display_names, get_model
# Heading rendered at the top of the demo page.
APP_TITLE = "Machine Learning CS 6140 Project: Pet Recognizer"
# Number of ranked predictions requested from a model and shown in the table.
TOP_K_DEFAULT = 5
# Custom stylesheet handed to gr.Blocks(css=...): dark page background, light
# text, blue links, and darker panels/buttons.
# NOTE(review): the `.gr-box` / `.gr-button` selectors only take effect if the
# installed Gradio version emits those class names — confirm against the
# pinned version.
DARK_CSS = """
body {
background-color: #0f172a !important;
}
.gradio-container {
background-color: #0f172a !important;
color: #e5e7eb !important;
}
h1, h2, h3, h4, p, li, label {
color: #e5e7eb !important;
}
a {
color: #60a5fa !important;
}
.gr-box {
background-color: #020617 !important;
border-radius: 10px;
}
.gr-button {
background-color: #1e293b !important;
color: #e5e7eb !important;
}
.gr-button:hover {
background-color: #334155 !important;
}
"""
# -----------------------------
# Helpers
# -----------------------------
def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
    """Turn top-k prediction entries into rows for the results table.

    Each row is ``[rank, class name, probability in percent]``. Ranks start
    at 1 and follow the input order. An entry without ``"class_name"`` is
    labelled ``"id=<class_id>"`` (``"id=?"`` when the id is missing as well),
    and a missing ``"probability"`` is treated as 0.0. Percentages are
    rounded to two decimals.
    """
    return [
        [
            position,
            item.get("class_name", f"id={item.get('class_id', '?')}"),
            round(float(item.get("probability", 0.0)) * 100.0, 2),
        ]
        for position, item in enumerate(top_k, start=1)
    ]
def run_inference(model_id: str, image) -> Dict[str, Any]:
    """Run the selected model on an image and package the result for the UI.

    Args:
        model_id: Identifier passed to ``get_model`` to obtain the model.
        image: Image forwarded unchanged to the model's ``predict``;
            ``None`` means nothing has been uploaded yet.

    Returns:
        A dict with two keys:
        ``"main_text"`` — a Markdown summary of the predicted class and its
        id, or a prompt to upload an image when ``image`` is ``None``;
        ``"topk_table"`` — rows produced by ``format_topk_for_table``
        (empty when there is no image or the model reports no top-k list).
    """
    if image is None:
        return {
            "main_text": "Please upload an image first.",
            "topk_table": [],
        }

    model = get_model(model_id)
    result = model.predict(image, top_k=TOP_K_DEFAULT)

    class_name = result.get("class_name", "Unknown")
    class_id = result.get("class_id", "N/A")
    # Treat a missing or None "top_k" the same way: nothing to tabulate.
    top_k = result.get("top_k") or []

    # Two trailing spaces before the newline are a Markdown hard line break;
    # with a single space the class name and the class id are rendered as one
    # run-on line.
    main_text = (
        f"**Predicted Class:** {class_name}  \n"
        f"**Class ID:** {class_id}"
    )
    return {
        "main_text": main_text,
        "topk_table": format_topk_for_table(top_k),
    }
# -----------------------------
# UI
# -----------------------------
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the pet recognizer.

    The page consists of a title/description section, usage instructions,
    and a two-column row: model selection, image upload and the run button
    on the left; the prediction text and the top-k table on the right.
    Clicking the button calls ``run_inference`` with the model id that
    corresponds to the selected display name.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    model_display_names = get_model_display_names()
    # The dropdown lists display names, so invert the mapping to translate the
    # user's selection back into the key expected by run_inference.
    name_to_id = {v: k for k, v in model_display_names.items()}
    # Preselect the first entry; fall back to "no selection" when nothing is
    # registered instead of failing with a bare StopIteration.
    default_display_name = next(iter(name_to_id), None)

    with gr.Blocks(css=DARK_CSS) as demo:
        # Title
        # Markdown text is kept flush-left so the string contains no leading
        # indentation.
        gr.Markdown(
            f"""
# {APP_TITLE}
This project demonstrates **pet breed recognition** using the
**Oxford-IIIT Pet Dataset**, comparing **classical machine learning models**
(Logistic Regression, SVM) with **deep feature-based models**
(Pretrained ResNet18).
**Dataset & Supported Breeds**
The models are trained on **37 cat and dog breeds** from the Oxford-IIIT Pet Dataset.
https://www.robots.ox.ac.uk/~vgg/data/pets/
"""
        )
        # Instructions
        gr.Markdown(
            """
## Instructions
1. **Upload** a clear, close-up image of a **cat or dog** belonging to one of the supported breeds
2. **Select a model** to run the recognition:
- **LR / SVM** → Expected to perform poorly on raw pixel inputs
- **ResNet-based models** → Use pretrained deep visual features and produce much better results
3. Click **Run Identification** to view the **Top-5 predictions**
"""
        )
        with gr.Row():
            # Left column
            with gr.Column(scale=1):
                gr.Markdown("### Select Model & Upload Image")
                model_dropdown = gr.Dropdown(
                    choices=list(name_to_id.keys()),
                    value=default_display_name,
                    label="Select Model",
                )
                image_input = gr.Image(
                    type="pil",
                    label="Upload your pet image (JPEG / PNG)",
                )
                run_button = gr.Button("Run Identification")
            # Right column
            with gr.Column(scale=1):
                gr.Markdown("### Model Prediction")
                main_output = gr.Markdown(
                    value="Prediction will appear here.",
                )
                topk_output = gr.Dataframe(
                    headers=["Rank", "Class Name", "Probability (%)"],
                    datatype=["number", "str", "number"],
                    column_count=3,
                    label=f"Top-{TOP_K_DEFAULT} Predictions",
                )

        # Button wiring
        def _gradio_infer(selected_display_name, img):
            """Map the dropdown selection to a model id and run inference."""
            # An empty or unknown selection gets a prompt rather than a
            # KeyError surfacing in the UI.
            if selected_display_name not in name_to_id:
                return "Please select a model first.", []
            model_id = name_to_id[selected_display_name]
            result = run_inference(model_id, img)
            return result["main_text"], result["topk_table"]

        run_button.click(
            fn=_gradio_infer,
            inputs=[model_dropdown, image_input],
            outputs=[main_output, topk_output],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server. The app
    # stays bound to the module-level name `demo`.
    demo = build_demo()
    demo.launch()