mariboo commited on
Commit
8fc11fa
·
verified ·
1 Parent(s): f3e93ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/python3
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import pandas as pd
6
+ import glob, os
7
+ import shoe_outlines_lib as sol
8
+ import matplotlib.pyplot as plt
9
+ import onnxruntime
10
+ import cv2
11
+
12
+ imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
13
+ imagenet_means = np.array(imagenet_stats[0], dtype=np.float32)[:, None, None]
14
+ imagenet_stds = np.array(imagenet_stats[1], dtype=np.float32)[:, None, None]
15
+ sz = (160, 256)
16
+
17
+ # Load the ONNX model
18
+ ort_session = onnxruntime.InferenceSession('shod-model.onnx')
19
+
20
+
21
+ def csv2image_fig(csv_file):
22
+ df = sol.csv2dfs([csv_file])[0]
23
+ fname = df.name
24
+ df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
25
+ df = sol.norm_by_x(df)
26
+ image = sol.coordsdf2image(df)
27
+ fig = plt.figure(figsize=(2, 4))
28
+ plt.plot(df['x'], df['y'], marker='', linestyle='-', color='b', label='Line')
29
+ plt.fill(df['x'], df['y'], color='blue', alpha=0.2)
30
+ plt.axis('equal')
31
+ plt.axis('off')
32
+ plt.gca().invert_yaxis()
33
+ return image, fig, fname
34
+
35
+
36
+ def get_predictions(images, bs=8):
37
+ ''' class 0 is "No shoe", class 1 is "Shoe" '''
38
+
39
+ def _softmax(logits):
40
+ exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) # Stability trick
41
+ return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
42
+
43
+ if isinstance(images, np.ndarray): images = [images]
44
+
45
+ images = np.stack([cv2.resize(image, sz) for image in images])
46
+ images = images.transpose(0,3,1,2).astype(np.float32)
47
+ images = (images / 255.0 - imagenet_means) / imagenet_stds
48
+
49
+ for b in range(0, len(images), bs):
50
+ ort_inputs = {ort_session.get_inputs()[0].name: images[b:b+bs]}
51
+ preds = ort_session.run(None, ort_inputs)[0]
52
+ all_preds = preds if b==0 else np.concatenate((all_preds, preds))
53
+ confidences = _softmax(all_preds)[:,1] # class 0 is "Bare", class 1 is "Shod"
54
+
55
+ return confidences
56
+
57
+
58
+ css = """
59
+ h1 {
60
+ text-align: center;
61
+ display:block;
62
+ vertical-align: middle;
63
+ }
64
+ #title-column {
65
+ padding: 0px !important; /* Remove padding from the parent column */
66
+ gap: 0px !important; /* Ensure gap is zero */
67
+ }
68
+ #title-and-subtitle {
69
+ margin: 0px !important;
70
+ padding: 0px !important;
71
+ }
72
+ .logo {
73
+ max-height: 128px;
74
+ display: inline-block;
75
+ vertical-align: middle;
76
+ }
77
+ """
78
+
79
+ with gr.Blocks(css=css) as app:
80
+ with gr.Column():
81
+ with gr.Row():
82
+ gr.Image(
83
+ value="paleostep-logo-cropped-128.png",
84
+ interactive=False,
85
+ show_label=False,
86
+ show_download_button=False,
87
+ show_share_button=False,
88
+ container=False,
89
+ show_fullscreen_button=False,
90
+ elem_id="logo",
91
+ )
92
+ with gr.Row():
93
+ with gr.Column(elem_id="title-column"):
94
+ gr.Markdown("""
95
+ # STEP: Shod Track Estimated Percentage
96
+ <p style='color: gray; text-align: center; font-style: italic; margin: 0; padding: 0;'>Mysteriously Accurate Rim Curvature INdex</p>
97
+ """, elem_id="title-and-subtitle")
98
+
99
+ #################################################################################
100
+ with gr.Tab('Single outline classification'):
101
+ with gr.Row():
102
+ gr_input = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="single", label="Upload Outline File")
103
+
104
+ with gr.Row():
105
+ gr.Label(value="Upload a .csv/.xlsx/.json file", visible=True, show_label=False)
106
+
107
+ with gr.Row():
108
+ gr_plot = gr.Plot(label="Outline Plot", show_label=True, visible=False)
109
+
110
+ with gr.Row():
111
+ gr_label = gr.Label(label="Classification", visible=False, show_label=False)
112
+
113
+ def _classify_image(csv_file):
114
+ try:
115
+ image, fig, fname = csv2image_fig(csv_file)
116
+ if len(image.shape) == 2: image = np.tile(image[...,None],(1,1,3))
117
+ confidence = get_predictions([image]).item()
118
+ classification = "Shoe" if confidence >= 0.5 else "No shoe"
119
+ return (
120
+ classification, {f"Shoe confidence: {100*confidence:.1f}": confidence}, gr.update(visible=True), # gr_label
121
+ fig, gr.update(visible=True, label=fname) # gr_plot
122
+ )
123
+ except Exception as e:
124
+ return str(e), str(e), gr.update(visible=True), None, gr.update(visible=False)
125
+
126
+ gr_input.upload(
127
+ fn=_classify_image,
128
+ inputs=[gr_input],
129
+ outputs=[gr_label, gr_label, gr_label, gr_plot, gr_plot],
130
+ )
131
+
132
+ gr_input.clear(
133
+ fn=lambda: (*([None]*2), *([gr.update(visible=False)]*2)),
134
+ inputs=[],
135
+ outputs=[gr_label, gr_plot, gr_label, gr_plot],
136
+ )
137
+
138
+
139
+ #################################################################################
140
+ with gr.Tab('Batch classification'):
141
+ with gr.Row():
142
+ gr_input_batch = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="multiple", label="Upload Outline File(s)")
143
+ with gr.Row():
144
+ gr.Label(value="Upload multiple .csv/.xlsx/.json files.", visible=True, show_label=False)
145
+ with gr.Row(visible=True):
146
+ with gr.Column():
147
+ gr_df = gr.Dataframe(label="Outlines", visible=False, show_label=False, row_count=10)
148
+ gr_results_file = gr.File(visible=False)
149
+
150
+ def _classify_batch(csv_files):
151
+ try:
152
+ for f in glob.glob("classification_results_*.csv"):
153
+ os.remove(f)
154
+
155
+ dfs = sol.csv2dfs(csv_files)
156
+ images = [np.tile(sol.coordsdf2image(df)[...,None],(1,1,3)) for df in dfs]
157
+ confidences = get_predictions(images)
158
+
159
+ out = []
160
+ for df, confidence in zip(dfs,confidences):
161
+ images.append(sol.coordsdf2image(df))
162
+ out.append({
163
+ 'Outline file': df.name,
164
+ 'Points': len(df),
165
+ 'Confidence': 100*confidence
166
+ })
167
+
168
+ df_out = pd.DataFrame(out)
169
+ timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
170
+ filename = f"classification_results_{timestamp}.csv"
171
+ df_out.to_csv(filename, index=False)
172
+
173
+ return df_out.style.format({'Confidence': '{:.1f}%'}), gr.update(visible=True), gr.update(visible=True, value=filename)
174
+
175
+ except Exception as e:
176
+ return pd.DataFrame({'Error': [str(e)]}), gr.update(visible=True), gr.update(visible=False)
177
+
178
+ gr_input_batch.upload(
179
+ fn=_classify_batch,
180
+ inputs=[gr_input_batch],
181
+ outputs=[gr_df, gr_df, gr_results_file],
182
+ )
183
+
184
+ gr_input_batch.clear(
185
+ fn=lambda: (None, *([gr.update(visible=False)]*2)),
186
+ inputs=[],
187
+ outputs=[gr_df, gr_df, gr_results_file],
188
+ )
189
+
190
+
191
+ app.launch(
192
+ server_port=2443,
193
+ share=False,
194
+ debug=False,
195
+ show_api=False
196
+ )