step / app.py
mariboo's picture
Update app.py
733e3de verified
#!/bin/python3
import gradio as gr
import numpy as np
import pandas as pd
import glob, os
import shoe_outlines_lib as sol
import matplotlib.pyplot as plt
import onnxruntime
import cv2
# ImageNet per-channel (R, G, B) means and standard deviations used to
# normalise model inputs. They are reshaped to (3, 1, 1) so they broadcast over
# channel-first (N, C, H, W) batches.
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
imagenet_means, imagenet_stds = (
    np.asarray(channel_values, dtype=np.float32).reshape(-1, 1, 1)
    for channel_values in imagenet_stats
)

# Size handed to cv2.resize, which interprets it as (width, height).
sz = (160, 256)

# Load the ONNX classifier once at import time; get_predictions reuses this
# single session for every call.
ort_session = onnxruntime.InferenceSession('shod-model.onnx')
def csv2image_fig(csv_file):
    """Load one outline file and return its raster image and a preview plot.

    Args:
        csv_file: A single outline file (path or Gradio upload object) as
            accepted by ``sol.csv2dfs``.

    Returns:
        ``(image, fig, fname)``:
        ``image`` is the raster produced by ``sol.coordsdf2image`` from the
        closed, ``sol.norm_by_x``-normalised outline; ``fig`` is a matplotlib
        Figure drawing the closed outline (not registered with pyplot, so it
        needs no ``plt.close``); ``fname`` is the ``name`` attribute that
        ``sol.csv2dfs`` attached to the parsed DataFrame.
    """
    df = sol.csv2dfs([csv_file])[0]
    # Capture the name first: pd.concat builds a new DataFrame, and custom
    # attributes such as ``name`` are not guaranteed to survive it.
    fname = df.name
    # Repeat the first point at the end so the outline forms a closed polygon.
    df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
    df = sol.norm_by_x(df)
    image = sol.coordsdf2image(df)

    # Use a standalone Figure and its Axes rather than pyplot's global state:
    # plt.figure() keeps every figure alive until plt.close() (a leak in a
    # long-running server) and its "current figure" is shared by concurrent
    # requests.
    fig = plt.Figure(figsize=(2, 4))
    ax = fig.add_subplot()
    ax.plot(df['x'], df['y'], marker='', linestyle='-', color='b', label='Line')
    ax.fill(df['x'], df['y'], color='blue', alpha=0.2)
    ax.axis('equal')
    ax.axis('off')
    # Flip the y axis (image-style coordinates, origin at the top).
    ax.invert_yaxis()
    return image, fig, fname
def get_predictions(images, bs=8):
    """Run the ONNX shoe classifier and return the class-1 probability per image.

    Class 0 is "No shoe" (bare) and class 1 is "Shoe" (shod).

    Args:
        images: A single image as a numpy array, a 4-D numpy array whose first
            axis indexes images, or any iterable of images. Each image is
            resized to ``sz`` with ``cv2.resize``; images that come out of the
            resize with no channel axis (grayscale) are replicated to three
            channels. Pixel values are divided by 255 before ImageNet
            normalisation, so they are presumably in 0-255 (TODO confirm for
            all callers).
        bs: Number of images sent to the ONNX session per ``run`` call.

    Returns:
        1-D float array holding the softmax probability of class 1 ("Shoe")
        for each input image, in input order. Empty when ``images`` is empty.

    Raises:
        ValueError: if ``bs`` is smaller than 1.
    """
    def _softmax(logits):
        # Subtract the row max before exponentiating for numerical stability.
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    def _resize_rgb(image):
        # cv2.resize returns a 2-D array for single-channel input, so fix up
        # the channel axis after resizing: grayscale -> 3 identical channels.
        resized = cv2.resize(image, sz)
        if resized.ndim == 2:
            resized = np.repeat(resized[..., None], 3, axis=-1)
        return resized

    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs!r}")

    # A lone array with fewer than 4 dims is one image, not a stack of images.
    if isinstance(images, np.ndarray) and images.ndim < 4:
        images = [images]
    images = list(images)
    if not images:
        return np.empty(0, dtype=np.float32)

    batch = np.stack([_resize_rgb(image) for image in images])
    # (N, H, W, C) -> (N, C, H, W), scale to [0, 1], then ImageNet-normalise.
    batch = batch.transpose(0, 3, 1, 2).astype(np.float32)
    batch = (batch / 255.0 - imagenet_means) / imagenet_stds

    # Run in chunks of ``bs`` and concatenate once at the end.
    input_name = ort_session.get_inputs()[0].name
    logits = np.concatenate([
        ort_session.run(None, {input_name: batch[start:start + bs]})[0]
        for start in range(0, len(batch), bs)
    ])
    return _softmax(logits)[:, 1]
# Custom stylesheet passed to gr.Blocks. The logo gr.Image is tagged with
# elem_id="logo", so its rule must use an id selector (#logo); a class
# selector (.logo) matches nothing.
css = """
h1 {
text-align: center;
display:block;
vertical-align: middle;
}
#title-column {
padding: 0px !important; /* Remove padding from the parent column */
gap: 0px !important; /* Ensure gap is zero */
}
#title-and-subtitle {
margin: 0px !important;
padding: 0px !important;
}
#title-and-subtitle .prose h1 {
margin: 0px !important;
padding: 0px !important;
}
#title-and-subtitle .prose p {
margin: 0px !important;
padding: 0px !important;
color: gray !important; /* Added */
text-align: center !important; /* Added */
font-style: italic !important; /* Added */
}
#logo {
max-height: 128px;
display: inline-block;
vertical-align: middle;
}
.gradio-container {
width: 1200px !important; /* Use !important to override defaults if needed */
margin: 0 auto;
}
"""
with gr.Blocks(css=css) as app:
    # Header: logo on top, then the title and subtitle.
    with gr.Column():
        with gr.Row():
            gr.Image(
                value="paleostep-logo-cropped-128.png",
                interactive=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                container=False,
                show_fullscreen_button=False,
                elem_id="logo",
            )
        with gr.Row():
            with gr.Column(elem_id="title-column"):
                gr.Markdown("""
# STEP: Shod Track Estimated Percentage
<p style='color: gray; text-align: center; font-style: italic; margin: 0; padding: 0;'>Mysteriously Accurate Rim Curvature INdex</p>
""", elem_id="title-and-subtitle")
    #################################################################################
    with gr.Tab('Single outline classification'):
        with gr.Row():
            gr_input = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="single", label="Upload Outline File")
        with gr.Row():
            gr.Label(value="Upload a .csv/.xlsx/.json file", visible=True, show_label=False)
        with gr.Row():
            gr_plot = gr.Plot(label="Outline Plot", show_label=True, visible=False)
        with gr.Row():
            gr_label = gr.Label(label="Classification", visible=False, show_label=False)

        def _classify_image(csv_file):
            """Classify one uploaded outline.

            Returns one update for ``gr_label`` and one for ``gr_plot``. On
            success the label shows the "Shoe" confidence and the plot shows
            the outline, titled with the file name. On any error the label
            shows the error text and the plot is cleared and hidden.
            """
            try:
                image, fig, fname = csv2image_fig(csv_file)
                if len(image.shape) == 2:
                    # The classifier expects three channels; replicate grayscale.
                    image = np.tile(image[..., None], (1, 1, 3))
                confidence = get_predictions([image]).item()
            except Exception as e:
                # UI boundary: surface the error to the user instead of crashing.
                return (
                    gr.update(value=str(e), visible=True),  # gr_label
                    gr.update(value=None, visible=False),  # gr_plot
                )
            return (
                gr.update(value={f"Shoe confidence: {100*confidence:.1f}%": confidence}, visible=True),  # gr_label
                gr.update(value=fig, visible=True, label=fname),  # gr_plot
            )

        # One output slot per component; each update carries value and
        # visibility together.
        gr_input.upload(
            fn=_classify_image,
            inputs=[gr_input],
            outputs=[gr_label, gr_plot],
        )
        gr_input.clear(
            fn=lambda: (gr.update(value=None, visible=False), gr.update(value=None, visible=False)),
            inputs=[],
            outputs=[gr_label, gr_plot],
        )
    #################################################################################
    with gr.Tab('Batch classification'):
        with gr.Row():
            gr_input_batch = gr.File(file_types=['.csv', '.xlsx', '.json'], file_count="multiple", label="Upload Outline File(s)")
        with gr.Row():
            gr.Label(value="Upload multiple .csv/.xlsx/.json files.", visible=True, show_label=False)
        with gr.Row(visible=True):
            with gr.Column():
                gr_df = gr.Dataframe(label="Outlines", visible=False, show_label=False, row_count=10)
                gr_results_file = gr.File(visible=False)

        def _classify_batch(csv_files):
            """Classify several uploaded outlines and write a results CSV.

            Returns one update for ``gr_df`` and one for ``gr_results_file``.
            On success the table lists file name, point count and "Shoe"
            confidence (percent), and the CSV is offered for download. On any
            error the table shows the error text and the download is hidden.
            """
            try:
                # Delete result files from earlier runs in the working directory.
                for f in glob.glob("classification_results_*.csv"):
                    os.remove(f)
                dfs = sol.csv2dfs(csv_files)
                # NOTE(review): unlike csv2image_fig (single-outline tab), the
                # outlines here are neither closed nor passed through
                # sol.norm_by_x before rasterising - confirm both tabs are meant
                # to preprocess identically.
                images = [np.tile(sol.coordsdf2image(df)[..., None], (1, 1, 3)) for df in dfs]
                confidences = get_predictions(images)
                df_out = pd.DataFrame([
                    {
                        'Outline file': df.name,
                        'Points': len(df),
                        'Confidence': 100*confidence,
                    }
                    for df, confidence in zip(dfs, confidences)
                ])
                timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                filename = f"classification_results_{timestamp}.csv"
                df_out.to_csv(filename, index=False)
                return (
                    gr.update(value=df_out.style.format({'Confidence': '{:.1f}%'}), visible=True),  # gr_df
                    gr.update(visible=True, value=filename),  # gr_results_file
                )
            except Exception as e:
                # UI boundary: show the error in the table instead of crashing.
                return (
                    gr.update(value=pd.DataFrame({'Error': [str(e)]}), visible=True),  # gr_df
                    gr.update(visible=False),  # gr_results_file
                )

        gr_input_batch.upload(
            fn=_classify_batch,
            inputs=[gr_input_batch],
            outputs=[gr_df, gr_results_file],
        )
        gr_input_batch.clear(
            fn=lambda: (gr.update(value=None, visible=False), gr.update(visible=False)),
            inputs=[],
            outputs=[gr_df, gr_results_file],
        )
# Start the Gradio server: no public share link, debug mode off, and
# show_api disabled.
app.launch(show_api=False, debug=False, share=False)