"""Gradio front end for the pet breed recognizer.

Lets the user pick one of the registered models, upload an image, and view
the predicted class together with a table of the top-k predictions.
"""

import gradio as gr
from typing import Any, Dict, List

from src.registry import get_model_display_names, get_model

APP_TITLE = "Machine Learning CS 6140 Project: Pet Recognizer"
TOP_K_DEFAULT = 5

DARK_CSS = """
body { background-color: #0f172a !important; }
.gradio-container { background-color: #0f172a !important; color: #e5e7eb !important; }
h1, h2, h3, h4, p, li, label { color: #e5e7eb !important; }
a { color: #60a5fa !important; }
.gr-box { background-color: #020617 !important; border-radius: 10px; }
.gr-button { background-color: #1e293b !important; color: #e5e7eb !important; }
.gr-button:hover { background-color: #334155 !important; }
"""

# Markdown blocks are kept flush-left at module level so that Python source
# indentation can never leak into the rendered text (four or more leading
# spaces would otherwise be interpreted as a Markdown code block).
INTRO_MARKDOWN = f"""
# {APP_TITLE}

This project demonstrates **pet breed recognition** using the **Oxford-IIIT Pet Dataset**, comparing
**classical machine learning models** (Logistic Regression, SVM) with **deep feature-based models** (Pretrained ResNet18).

**Dataset & Supported Breeds**

The models are trained on **37 cat and dog breeds** from the Oxford-IIIT Pet Dataset.

https://www.robots.ox.ac.uk/~vgg/data/pets/
"""

# The "Top-k" wording is derived from TOP_K_DEFAULT so it stays in sync with
# the number of predictions actually requested and with the table label.
INSTRUCTIONS_MARKDOWN = f"""
## Instructions

1. **Upload** a clear, close-up image of a **cat or dog** belonging to one of the supported breeds
2. **Select a model** to run the recognition:
   - **LR / SVM** → Expected to perform poorly on raw pixel inputs
   - **ResNet-based models** → Use pretrained deep visual features and produce much better results
3. Click **Run Identification** to view the **Top-{TOP_K_DEFAULT} predictions**
"""


# -----------------------------
# Helpers
# -----------------------------
def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
    """Convert top-k prediction entries into rows for the results table.

    Args:
        top_k: Prediction entries in rank order. Each entry is read with
            ``.get`` for the keys ``"class_name"``, ``"class_id"`` and
            ``"probability"``.

    Returns:
        One ``[rank, class_name, probability_percent]`` row per entry. Ranks
        start at 1. A missing class name falls back to ``"id=<class_id>"``
        (``"id=?"`` when the id is missing as well); a missing probability is
        treated as 0.0. The probability is scaled to a percentage and rounded
        to two decimals.
    """
    return [
        [
            rank,
            entry.get("class_name", f"id={entry.get('class_id', '?')}"),
            round(float(entry.get("probability", 0.0)) * 100.0, 2),
        ]
        for rank, entry in enumerate(top_k, start=1)
    ]


def run_inference(model_id: str, image) -> Dict[str, Any]:
    """Run the selected model on an image and package the result for the UI.

    Args:
        model_id: Identifier passed to ``get_model`` to obtain the model.
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A dict with two keys: ``"main_text"`` (Markdown summary of the
        prediction, or a prompt to upload an image) and ``"topk_table"``
        (rows produced by ``format_topk_for_table``; empty without an image).
    """
    if image is None:
        return {
            "main_text": "Please upload an image first.",
            "topk_table": [],
        }

    model = get_model(model_id)
    result = model.predict(image, top_k=TOP_K_DEFAULT)

    class_name = result.get("class_name", "Unknown")
    class_id = result.get("class_id", "N/A")
    top_k = result.get("top_k", [])

    # Two trailing spaces before the newline form a Markdown hard line break;
    # with a single space both fields are rendered on the same line.
    main_text = (
        f"**Predicted Class:** {class_name}  \n"
        f"**Class ID:** {class_id}"
    )

    return {
        "main_text": main_text,
        "topk_table": format_topk_for_table(top_k),
    }


# -----------------------------
# UI
# -----------------------------
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks interface for the pet recognizer.

    The layout has a title and instructions on top, a left column with the
    model dropdown, image upload and run button, and a right column with the
    prediction text and the top-k table.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    model_display_names = get_model_display_names()
    # The dropdown shows display names, so invert the mapping to recover the
    # model id from the user's selection.
    name_to_id = {v: k for k, v in model_display_names.items()}
    # With no registered models the dropdown simply starts without a
    # selection instead of aborting the UI build with a bare StopIteration.
    default_display_name = next(iter(name_to_id), None)

    with gr.Blocks(css=DARK_CSS) as demo:
        # Title
        gr.Markdown(INTRO_MARKDOWN)

        # Instructions
        gr.Markdown(INSTRUCTIONS_MARKDOWN)

        with gr.Row():
            # Left column
            with gr.Column(scale=1):
                gr.Markdown("### Select Model & Upload Image")
                model_dropdown = gr.Dropdown(
                    choices=list(name_to_id),
                    value=default_display_name,
                    label="Select Model",
                )
                image_input = gr.Image(
                    type="pil",
                    label="Upload your pet image (JPEG / PNG)",
                )
                run_button = gr.Button("Run Identification")

            # Right column
            with gr.Column(scale=1):
                gr.Markdown("### Model Prediction")
                main_output = gr.Markdown(
                    value="Prediction will appear here.",
                )
                topk_output = gr.Dataframe(
                    headers=["Rank", "Class Name", "Probability (%)"],
                    datatype=["number", "str", "number"],
                    column_count=3,
                    label=f"Top-{TOP_K_DEFAULT} Predictions",
                )

        # Button wiring
        def _gradio_infer(selected_display_name, img):
            """Map the dropdown selection to a model id and run inference."""
            # No selection (or a name that is not registered) gets a friendly
            # message rather than a KeyError surfacing in the UI.
            if selected_display_name not in name_to_id:
                return "Please select a model first.", []
            result = run_inference(name_to_id[selected_display_name], img)
            return result["main_text"], result["topk_table"]

        run_button.click(
            fn=_gradio_infer,
            inputs=[model_dropdown, image_input],
            outputs=[main_output, topk_output],
        )

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()