|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import glob, os |
|
|
import shoe_outlines_lib as sol |
|
|
import matplotlib.pyplot as plt |
|
|
import onnxruntime |
|
|
import cv2 |
|
|
|
|
|
# Standard ImageNet per-channel normalization statistics, as (means, stds).
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# float32 arrays shaped (3, 1, 1) so they broadcast over a channels-first
# (N, C, H, W) batch during normalization.
imagenet_means, imagenet_stds = (
    np.asarray(values, dtype=np.float32).reshape(-1, 1, 1)
    for values in imagenet_stats
)

# Target size handed to cv2.resize, which reads dsize as (width, height).
sz = (160, 256)
|
|
|
|
|
|
|
|
# ONNX Runtime session for the classifier used by get_predictions. The model
# path is resolved relative to the current working directory.
# `providers` is passed explicitly: since onnxruntime 1.9, builds with more than
# one execution provider (e.g. GPU builds) raise ValueError when it is omitted.
ort_session = onnxruntime.InferenceSession(
    'shod-model.onnx',
    providers=onnxruntime.get_available_providers(),
)
|
|
|
|
|
|
|
|
def csv2image_fig(csv_file):
    '''Load one outline file and render it as a model-input image and as a plot.

    Args:
        csv_file: An uploaded outline file, in whatever form ``sol.csv2dfs``
            accepts inside a one-element list.

    Returns:
        (image, fig, fname):
            image: the raster produced by ``sol.coordsdf2image`` from the
                closed, x-normalized outline.
            fig: a matplotlib Figure showing the outline (line plus light fill),
                with axes hidden and the y axis inverted.
            fname: the ``name`` attribute of the DataFrame returned by
                ``sol.csv2dfs``.
    '''
    df = sol.csv2dfs([csv_file])[0]
    # pd.concat below returns a new DataFrame that does not carry this ad-hoc
    # attribute, so read it first.
    fname = df.name

    # Repeat the first point at the end so the outline forms a closed loop.
    df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
    df = sol.norm_by_x(df)
    image = sol.coordsdf2image(df)

    # Draw on an explicit Figure/Axes instead of pyplot's global
    # "current figure/axes", which is shared by every caller in the process.
    fig, ax = plt.subplots(figsize=(2, 4))
    ax.plot(df['x'], df['y'], marker='', linestyle='-', color='b', label='Line')
    ax.fill(df['x'], df['y'], color='blue', alpha=0.2)
    ax.axis('equal')
    ax.axis('off')
    ax.invert_yaxis()

    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object can still be rendered/saved.
    plt.close(fig)

    return image, fig, fname
|
|
|
|
|
|
|
|
def get_predictions(images, bs=8):
    ''' Return the class-1 ("Shoe") probability for each image.

    Class 0 is "No shoe", class 1 is "Shoe".

    Args:
        images: A single image as an ndarray, a 4-D ndarray treated as a batch
            of images along axis 0, or any iterable of images. Each image is
            expected to be H x W x 3 with pixel values on a 0-255 scale: it is
            resized to ``sz``, divided by 255, moved to channels-first layout and
            normalized with the 3-channel ImageNet statistics.
        bs: Maximum number of images passed to ``ort_session.run`` per call.

    Returns:
        1-D array of softmax probabilities for class 1, one per input image,
        in input order. An empty input yields an empty float32 array.

    Raises:
        ValueError: if ``bs`` is smaller than 1.
    '''
    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs!r}")

    def _softmax(logits):
        # Shift by the row-wise max before exponentiating for numerical stability.
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    # Iterating a lone image ndarray would walk its rows, so wrap it;
    # a 4-D ndarray is already a batch of images.
    if isinstance(images, np.ndarray) and images.ndim != 4:
        images = [images]

    # cv2.resize reads dsize as (width, height).
    resized = [cv2.resize(image, sz) for image in images]
    if not resized:
        return np.empty(0, dtype=np.float32)

    # NHWC -> NCHW, divide by 255, then per-channel ImageNet normalization.
    batch = np.stack(resized).transpose(0, 3, 1, 2).astype(np.float32)
    batch = (batch / 255.0 - imagenet_means) / imagenet_stds

    input_name = ort_session.get_inputs()[0].name
    # Gather per-chunk logits and concatenate once, rather than re-copying the
    # accumulated result after every chunk.
    logits = np.concatenate([
        ort_session.run(None, {input_name: batch[start:start + bs]})[0]
        for start in range(0, len(batch), bs)
    ])

    return _softmax(logits)[:, 1]
|
|
|
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=css) for the app layout.
#   - h1: level-1 headings (the "# STEP: ..." title) are centered.
#   - #title-column / #title-and-subtitle: the components created with those
#     elem_id values; spacing around the title is removed and the subtitle
#     paragraph is styled gray, centered and italic.
#   - .logo: elements carrying the CSS class "logo".
#     NOTE(review): the logo gr.Image is created with elem_id="logo" (an id,
#     not a class), so this class selector does not appear to match it —
#     confirm whether "#logo" or elem_classes=["logo"] was intended.
#   - .gradio-container: the app's root container, fixed at 1200px and
#     horizontally centered.
css = """
h1 {
text-align: center;
display:block;
vertical-align: middle;
}
#title-column {
padding: 0px !important; /* Remove padding from the parent column */
gap: 0px !important; /* Ensure gap is zero */
}
#title-and-subtitle {
margin: 0px !important;
padding: 0px !important;
}
#title-and-subtitle .prose h1 {
margin: 0px !important;
padding: 0px !important;
}
#title-and-subtitle .prose p {
margin: 0px !important;
padding: 0px !important;
color: gray !important; /* Added */
text-align: center !important; /* Added */
font-style: italic !important; /* Added */
}
.logo {
max-height: 128px;
display: inline-block;
vertical-align: middle;
}
.gradio-container {
width: 1200px !important; /* Use !important to override defaults if needed */
margin: 0 auto;
}
"""
|
|
|
|
|
with gr.Blocks(css=css) as app:

    def _as_rgb(image):
        '''Return ``image`` with a channel axis: a 2-D image is replicated into
        three identical channels; any other image is returned unchanged.

        Shared by the single and batch tabs so both prepare model inputs the
        same way.
        '''
        if image.ndim == 2:
            image = np.tile(image[..., None], (1, 1, 3))
        return image

    # Header: logo row followed by the title/subtitle row.
    with gr.Column():
        with gr.Row():
            # NOTE(review): the `.logo` rule in `css` is a class selector, but this
            # component only sets elem_id="logo"; confirm the intended styling.
            gr.Image(
                value="paleostep-logo-cropped-128.png",
                interactive=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                container=False,
                show_fullscreen_button=False,
                elem_id="logo",
            )
        with gr.Row():
            with gr.Column(elem_id="title-column"):
                gr.Markdown("""
# STEP: Shod Track Estimated Percentage
<p style='color: gray; text-align: center; font-style: italic; margin: 0; padding: 0;'>Mysteriously Accurate Rim Curvature INdex</p>
""", elem_id="title-and-subtitle")

    with gr.Tab('Single outline classification'):
        with gr.Row():
            gr_input = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="single", label="Upload Outline File")
        with gr.Row():
            gr.Label(value="Upload a .csv/.xlsx/.json file", visible=True, show_label=False)
        with gr.Row():
            gr_plot = gr.Plot(label="Outline Plot", show_label=True, visible=False)
        with gr.Row():
            gr_label = gr.Label(label="Classification", visible=False, show_label=False)

        def _classify_image(csv_file):
            '''Classify one uploaded outline file.

            Returns:
                (label update, plot update). On success the label shows the
                shoe confidence and the plot shows the outline, titled with the
                outline name. On any exception the label shows the error message
                and the plot is cleared and hidden.
            '''
            try:
                image, fig, fname = csv2image_fig(csv_file)
                confidence = get_predictions([_as_rgb(image)]).item()
                return (
                    gr.update(value={f"Shoe confidence: {100*confidence:.1f}": confidence}, visible=True),
                    gr.update(value=fig, visible=True, label=fname),
                )
            except Exception as e:
                return (
                    gr.update(value=str(e), visible=True),
                    gr.update(value=None, visible=False),
                )

        # One update per output component, each carrying value and visibility.
        gr_input.upload(
            fn=_classify_image,
            inputs=[gr_input],
            outputs=[gr_label, gr_plot],
        )

        gr_input.clear(
            fn=lambda: (gr.update(value=None, visible=False), gr.update(value=None, visible=False)),
            inputs=[],
            outputs=[gr_label, gr_plot],
        )

    with gr.Tab('Batch classification'):
        with gr.Row():
            gr_input_batch = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="multiple", label="Upload Outline File(s)")
        with gr.Row():
            gr.Label(value="Upload multiple .csv/.xlsx/.json files.", visible=True, show_label=False)
        with gr.Row(visible=True):
            with gr.Column():
                gr_df = gr.Dataframe(label="Outlines", visible=False, show_label=False, row_count=10)
                gr_results_file = gr.File(visible=False)

        def _classify_batch(csv_files):
            '''Classify several uploaded outline files in one pass.

            Deletes any ``classification_results_*.csv`` in the working
            directory, then writes a new timestamped results CSV there.

            Returns:
                (table update, results-file update). The table lists each outline
                file, its number of points and its shoe confidence (percent). On
                any exception the table shows the error and the file is hidden.
            '''
            try:
                # Best-effort cleanup of result files from earlier runs; a file
                # that has already disappeared is not an error.
                for old_file in glob.glob("classification_results_*.csv"):
                    try:
                        os.remove(old_file)
                    except FileNotFoundError:
                        pass

                dfs = sol.csv2dfs(csv_files)
                images = [_as_rgb(sol.coordsdf2image(df)) for df in dfs]
                confidences = get_predictions(images)

                df_out = pd.DataFrame([
                    {
                        'Outline file': df.name,
                        'Points': len(df),
                        'Confidence': 100*confidence,
                    }
                    for df, confidence in zip(dfs, confidences)
                ])

                timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                filename = f"classification_results_{timestamp}.csv"
                df_out.to_csv(filename, index=False)

                return (
                    gr.update(value=df_out.style.format({'Confidence': '{:.1f}%'}), visible=True),
                    gr.update(visible=True, value=filename),
                )

            except Exception as e:
                return (
                    gr.update(value=pd.DataFrame({'Error': [str(e)]}), visible=True),
                    gr.update(visible=False),
                )

        gr_input_batch.upload(
            fn=_classify_batch,
            inputs=[gr_input_batch],
            outputs=[gr_df, gr_results_file],
        )

        gr_input_batch.clear(
            fn=lambda: (gr.update(value=None, visible=False), gr.update(visible=False)),
            inputs=[],
            outputs=[gr_df, gr_results_file],
        )
|
|
|
|
|
|
|
|
# Start the Gradio server without a public share link, with debug mode off,
# and with the auto-generated API page hidden.
app.launch(show_api=False, debug=False, share=False)
|
|
|