# NOTE: captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
from typing import Any, Dict, List, Optional

import gradio as gr

from src.registry import get_model, get_model_display_names
# Heading rendered as the top-level Markdown title of the demo page.
APP_TITLE = "Machine Learning CS 6140 Project: Pet Recognizer"

# Number of ranked predictions requested from a model and shown in the table.
TOP_K_DEFAULT = 5

# Dark-theme stylesheet passed to ``gr.Blocks(css=...)``.
DARK_CSS = """
body {
    background-color: #0f172a !important;
}
.gradio-container {
    background-color: #0f172a !important;
    color: #e5e7eb !important;
}
h1, h2, h3, h4, p, li, label {
    color: #e5e7eb !important;
}
a {
    color: #60a5fa !important;
}
.gr-box {
    background-color: #020617 !important;
    border-radius: 10px;
}
.gr-button {
    background-color: #1e293b !important;
    color: #e5e7eb !important;
}
.gr-button:hover {
    background-color: #334155 !important;
}
"""
# -----------------------------
# Helpers
# -----------------------------
def format_topk_for_table(
    top_k: Optional[List[Dict[str, Any]]],
) -> List[List[Any]]:
    """Convert top-k prediction entries into rows for the results table.

    Args:
        top_k: Prediction entries in rank order. Each entry is a mapping that
            may carry ``"class_name"``, ``"class_id"`` and ``"probability"``
            (a fraction, converted here to a percentage). ``None`` is treated
            the same as an empty list.

    Returns:
        One ``[rank, class_name, probability_percent]`` row per entry, with
        ranks starting at 1 and the percentage rounded to two decimals.
    """
    rows: List[List[Any]] = []
    for rank, entry in enumerate(top_k or [], start=1):
        # Fall back to the numeric id (or "?") when no readable name exists.
        class_name = entry.get("class_name", f"id={entry.get('class_id', '?')}")
        prob = entry.get("probability")
        # A missing or None probability is reported as 0% rather than raising.
        percent = 0.0 if prob is None else round(float(prob) * 100.0, 2)
        rows.append([rank, class_name, percent])
    return rows
def run_inference(model_id: str, image) -> Dict[str, Any]:
    """Run one model on one image and package the result for the UI.

    Args:
        model_id: Identifier passed to ``get_model`` to obtain the model.
        image: The uploaded image, or ``None`` when nothing was uploaded.
            It is forwarded unchanged to ``model.predict``.

    Returns:
        A dict with two keys:
        ``"main_text"`` is a Markdown summary (or a prompt to upload an image);
        ``"topk_table"`` is the list of rows from ``format_topk_for_table``.
    """
    if image is None:
        return {
            "main_text": "Please upload an image first.",
            "topk_table": [],
        }

    model = get_model(model_id)
    # The prediction is read with ``.get`` below, so absent keys fall back to
    # placeholders instead of raising.
    result = model.predict(image, top_k=TOP_K_DEFAULT)

    class_name = result.get("class_name", "Unknown")
    class_id = result.get("class_id", "N/A")
    # ``or []`` also covers a model that returns an explicit None for top_k.
    top_k = result.get("top_k") or []

    # Two trailing spaces before the newline form a Markdown hard line break,
    # so the class name and class id render on separate lines.
    main_text = (
        f"**Predicted Class:** {class_name}  \n"
        f"**Class ID:** {class_id}"
    )
    return {
        "main_text": main_text,
        "topk_table": format_topk_for_table(top_k),
    }
# -----------------------------
# UI
# -----------------------------
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks interface for the pet recognizer.

    The page has a title, an instructions section, a left column for choosing
    a model and uploading an image, and a right column that shows the
    predicted class and a top-k table. Clicking the button runs
    ``run_inference`` for the selected model.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    # The dropdown shows display names, while ``run_inference`` needs the
    # matching key from the registry mapping, so invert the mapping once.
    model_display_names = get_model_display_names()
    name_to_id = {v: k for k, v in model_display_names.items()}
    # Use None (no preselection) rather than raising StopIteration when the
    # registry is empty.
    default_display_name = next(iter(name_to_id), None)

    with gr.Blocks(css=DARK_CSS) as demo:
        # Title
        gr.Markdown(
            f"""
            # {APP_TITLE}

            This project demonstrates **pet breed recognition** using the
            **Oxford-IIIT Pet Dataset**, comparing **classical machine learning models**
            (Logistic Regression, SVM) with **deep feature-based models**
            (Pretrained ResNet18).

            **Dataset & Supported Breeds**

            The models are trained on **37 cat and dog breeds** from the Oxford-IIIT Pet Dataset.
            https://www.robots.ox.ac.uk/~vgg/data/pets/
            """
        )

        # Instructions
        gr.Markdown(
            """
            ## Instructions

            1. **Upload** a clear, close-up image of a **cat or dog** belonging to one of the supported breeds
            2. **Select a model** to run the recognition:
               - **LR / SVM** — Expected to perform poorly on raw pixel inputs
               - **ResNet-based models** — Use pretrained deep visual features and produce much better results
            3. Click **Run Identification** to view the **Top-5 predictions**
            """
        )

        with gr.Row():
            # Left column
            with gr.Column(scale=1):
                gr.Markdown("### Select Model & Upload Image")
                model_dropdown = gr.Dropdown(
                    choices=list(name_to_id),
                    value=default_display_name,
                    label="Select Model",
                )
                image_input = gr.Image(
                    type="pil",
                    label="Upload your pet image (JPEG / PNG)",
                )
                run_button = gr.Button("Run Identification")

            # Right column
            with gr.Column(scale=1):
                gr.Markdown("### Model Prediction")
                main_output = gr.Markdown(
                    value="Prediction will appear here.",
                )
                topk_output = gr.Dataframe(
                    headers=["Rank", "Class Name", "Probability (%)"],
                    datatype=["number", "str", "number"],
                    column_count=3,
                    label=f"Top-{TOP_K_DEFAULT} Predictions",
                )

        # Button wiring
        def _gradio_infer(selected_display_name, img):
            """Translate the dropdown selection to a model id and run it."""
            # Guard against an empty or stale selection instead of raising
            # KeyError inside the event handler.
            if selected_display_name not in name_to_id:
                return "Please select a model first.", []
            model_id = name_to_id[selected_display_name]
            result = run_inference(model_id, img)
            return result["main_text"], result["topk_table"]

        run_button.click(
            fn=_gradio_infer,
            inputs=[model_dropdown, image_input],
            outputs=[main_output, topk_output],
        )

    return demo
if __name__ == "__main__":
    # Build and serve the UI only when run as a script; importing this module
    # has no side effects. The app is bound to the module-level name ``demo``.
    demo = build_demo()
    demo.launch()