Shashwat98 commited on
Commit
5f8adf2
·
verified ·
1 Parent(s): f6c40fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -152
app.py CHANGED
@@ -1,152 +1,169 @@
1
- # ui/app.py
2
-
3
- import gradio as gr
4
- from typing import Any, Dict, List
5
-
6
- from src.registry import get_model_display_names, get_model
7
-
8
- APP_TITLE = "PetRecog – Oxford-IIIT Pet Identification"
9
- APP_DESC = (
10
- "Upload a pet image, choose a model, and compare predictions across "
11
- "classical (LR, SVM) and deep-feature (ResNet) models."
12
- )
13
- TOP_K_DEFAULT = 5
14
-
15
-
16
- def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
17
- """
18
- Convert the model's top_k list of dicts into a 2D list suitable for gr.Dataframe.
19
-
20
- Expected each entry in top_k to look like:
21
- { 'class_id': int, 'class_name': str, 'probability': float }
22
- """
23
- rows = []
24
- for rank, entry in enumerate(top_k, start=1):
25
- class_name = entry.get("class_name", f"id={entry.get('class_id', '?')}")
26
- prob = entry.get("probability", 0.0)
27
- rows.append([rank, class_name, round(float(prob) * 100.0, 2)])
28
- return rows
29
-
30
-
31
- def run_inference(model_id: str, image) -> Dict[str, Any]:
32
- """
33
- Wrapper called by Gradio.
34
-
35
- Inputs:
36
- - model_id: key from the registry
37
- - image: PIL image object from gr.Image (type='pil')
38
-
39
- Outputs (as a dict mapped to Gradio components in the UI):
40
- - main_text: formatted prediction string
41
- - topk_table: 2D list for gr.Dataframe
42
- """
43
- if image is None:
44
- return {
45
- "main_text": "⚠️ Please upload an image first.",
46
- "topk_table": [],
47
- }
48
-
49
- # Get the model instance (lazy-loaded via registry)
50
- model = get_model(model_id)
51
-
52
- # All models follow the shared predict API:
53
- # predict(PIL.Image, top_k=TOP_K_DEFAULT) -> {
54
- # 'class_id', 'class_name', 'probabilities', 'top_k'
55
- # }
56
- result = model.predict(image, top_k=TOP_K_DEFAULT)
57
-
58
- class_name = result.get("class_name", "Unknown")
59
- class_id = result.get("class_id", "N/A")
60
- top_k = result.get("top_k", [])
61
-
62
- main_text = f"**Predicted Class:** {class_name} \n" f"**Class ID:** {class_id}"
63
-
64
- table = format_topk_for_table(top_k)
65
-
66
- return {
67
- "main_text": main_text,
68
- "topk_table": table,
69
- }
70
-
71
-
72
- def build_demo() -> gr.Blocks:
73
- model_display_names = get_model_display_names()
74
- # Gradio dropdown will show pretty display_name, but we need to map back to ids.
75
- id_to_name = model_display_names
76
- name_to_id = {v: k for k, v in id_to_name.items()}
77
-
78
- default_display_name = next(iter(name_to_id.keys())) if name_to_id else None
79
-
80
- with gr.Blocks(css="""
81
- body { background: #fbead8; }
82
- .noble-header { text-align: center; margin-bottom: 1.0rem; }
83
- .noble-title { font-size: 2.0rem; font-weight: 800; color: #5b3b27; }
84
- .noble-subtitle { font-size: 0.95rem; color: #7a5b45; }
85
- """) as demo:
86
- # Header
87
- with gr.Row(elem_classes="noble-header"):
88
- gr.Markdown(
89
- f"### {APP_TITLE}\n{APP_DESC}",
90
- elem_classes="noble-title"
91
- )
92
-
93
- with gr.Row():
94
- # Left column: controls
95
- with gr.Column(scale=1):
96
- gr.Markdown("#### 1️⃣ Select Model & Upload Image")
97
-
98
- model_dropdown = gr.Dropdown(
99
- choices=list(name_to_id.keys()),
100
- value=default_display_name,
101
- label="Select Model",
102
- )
103
-
104
- image_input = gr.Image(
105
- type="pil",
106
- label="Upload your pet image (JPEG/PNG)",
107
- )
108
-
109
- run_button = gr.Button("Run Identification")
110
-
111
- # Right column: output
112
- with gr.Column(scale=1):
113
- gr.Markdown("#### 2️⃣ Model Prediction")
114
-
115
- main_output = gr.Markdown(
116
- value="Prediction will appear here.",
117
- label="Prediction",
118
- )
119
-
120
- topk_output = gr.Dataframe(
121
- headers=["Rank", "Class Name", "Probability (%)"],
122
- datatype=["number", "str", "number"],
123
- col_count=(3, "fixed"),
124
- label=f"Top-{TOP_K_DEFAULT} Predictions",
125
- )
126
-
127
- # Wiring: button click -> inference
128
- def _gradio_infer(selected_display_name, img):
129
- if selected_display_name is None:
130
- return {
131
- main_output: "⚠️ Please select a model.",
132
- topk_output: [],
133
- }
134
- model_id = name_to_id[selected_display_name]
135
- result = run_inference(model_id, img)
136
- return {
137
- main_output: result["main_text"],
138
- topk_output: result["topk_table"],
139
- }
140
-
141
- run_button.click(
142
- fn=_gradio_infer,
143
- inputs=[model_dropdown, image_input],
144
- outputs=[main_output, topk_output],
145
- )
146
-
147
- return demo
148
-
149
-
150
- if __name__ == "__main__":
151
- demo = build_demo()
152
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Any, Dict, List
3
+
4
+ from src.registry import get_model_display_names, get_model
5
+
6
+ APP_TITLE = "Machine Learning CS 6140 Project: Pet Recognizer"
7
+ TOP_K_DEFAULT = 5
8
+
9
+ DARK_CSS = """
10
+ body {
11
+ background-color: #0f172a !important;
12
+ }
13
+
14
+ .gradio-container {
15
+ background-color: #0f172a !important;
16
+ color: #e5e7eb !important;
17
+ }
18
+
19
+ h1, h2, h3, h4, p, li, label {
20
+ color: #e5e7eb !important;
21
+ }
22
+
23
+ a {
24
+ color: #60a5fa !important;
25
+ }
26
+
27
+ .gr-box {
28
+ background-color: #020617 !important;
29
+ border-radius: 10px;
30
+ }
31
+
32
+ .gr-button {
33
+ background-color: #1e293b !important;
34
+ color: #e5e7eb !important;
35
+ }
36
+
37
+ .gr-button:hover {
38
+ background-color: #334155 !important;
39
+ }
40
+ """
41
+
42
+
43
+ # -----------------------------
44
+ # Helpers
45
+ # -----------------------------
46
+ def format_topk_for_table(top_k: List[Dict[str, Any]]) -> List[List[Any]]:
47
+ rows = []
48
+ for rank, entry in enumerate(top_k, start=1):
49
+ class_name = entry.get("class_name", f"id={entry.get('class_id', '?')}")
50
+ prob = entry.get("probability", 0.0)
51
+ rows.append([rank, class_name, round(float(prob) * 100.0, 2)])
52
+ return rows
53
+
54
+
55
+ def run_inference(model_id: str, image) -> Dict[str, Any]:
56
+ if image is None:
57
+ return {
58
+ "main_text": "Please upload an image first.",
59
+ "topk_table": [],
60
+ }
61
+
62
+ model = get_model(model_id)
63
+ result = model.predict(image, top_k=TOP_K_DEFAULT)
64
+
65
+ class_name = result.get("class_name", "Unknown")
66
+ class_id = result.get("class_id", "N/A")
67
+ top_k = result.get("top_k", [])
68
+
69
+ main_text = (
70
+ f"**Predicted Class:** {class_name} \n"
71
+ f"**Class ID:** {class_id}"
72
+ )
73
+
74
+ return {
75
+ "main_text": main_text,
76
+ "topk_table": format_topk_for_table(top_k),
77
+ }
78
+
79
+
80
+ # -----------------------------
81
+ # UI
82
+ # -----------------------------
83
+ def build_demo() -> gr.Blocks:
84
+ model_display_names = get_model_display_names()
85
+ name_to_id = {v: k for k, v in model_display_names.items()}
86
+ default_display_name = next(iter(name_to_id.keys()))
87
+
88
+ with gr.Blocks(css=DARK_CSS) as demo:
89
+
90
+ # Title
91
+ gr.Markdown(
92
+ f"""
93
+ # {APP_TITLE}
94
+
95
+ This project demonstrates **pet breed recognition** using the
96
+ **Oxford-IIIT Pet Dataset**, comparing **classical machine learning models**
97
+ (Logistic Regression, SVM) with **deep feature-based models**
98
+ (Pretrained ResNet18).
99
+
100
+ **Dataset & Supported Breeds**
101
+ The models are trained on **37 cat and dog breeds** from the Oxford-IIIT Pet Dataset.
102
+ https://www.robots.ox.ac.uk/~vgg/data/pets/
103
+ """
104
+ )
105
+
106
+ # Instructions
107
+ gr.Markdown(
108
+ """
109
+ ## Instructions
110
+
111
+ 1. **Upload** a clear, close-up image of a **cat or dog** belonging to one of the supported breeds
112
+ 2. **Select a model** to run the recognition:
113
+ - **LR / SVM** → Expected to perform poorly on raw pixel inputs
114
+ - **ResNet-based models** → Use pretrained deep visual features and produce much better results
115
+ 3. Click **Run Identification** to view the **Top-5 predictions**
116
+ """
117
+ )
118
+
119
+ with gr.Row():
120
+ # Left column
121
+ with gr.Column(scale=1):
122
+ gr.Markdown("### Select Model & Upload Image")
123
+
124
+ model_dropdown = gr.Dropdown(
125
+ choices=list(name_to_id.keys()),
126
+ value=default_display_name,
127
+ label="Select Model",
128
+ )
129
+
130
+ image_input = gr.Image(
131
+ type="pil",
132
+ label="Upload your pet image (JPEG / PNG)",
133
+ )
134
+
135
+ run_button = gr.Button("Run Identification")
136
+
137
+ # Right column
138
+ with gr.Column(scale=1):
139
+ gr.Markdown("### Model Prediction")
140
+
141
+ main_output = gr.Markdown(
142
+ value="Prediction will appear here.",
143
+ )
144
+
145
+ topk_output = gr.Dataframe(
146
+ headers=["Rank", "Class Name", "Probability (%)"],
147
+ datatype=["number", "str", "number"],
148
+ column_count=3,
149
+ label=f"Top-{TOP_K_DEFAULT} Predictions",
150
+ )
151
+
152
+ # Button wiring
153
+ def _gradio_infer(selected_display_name, img):
154
+ model_id = name_to_id[selected_display_name]
155
+ result = run_inference(model_id, img)
156
+ return result["main_text"], result["topk_table"]
157
+
158
+ run_button.click(
159
+ fn=_gradio_infer,
160
+ inputs=[model_dropdown, image_input],
161
+ outputs=[main_output, topk_output],
162
+ )
163
+
164
+ return demo
165
+
166
+
167
+ if __name__ == "__main__":
168
+ demo = build_demo()
169
+ demo.launch()