# Author: Aditya Sahu
# Commit: "Update app.py" (bb49fb2, verified)
import gradio as gr
import numpy as np
from PIL import Image
import os
from synet.backends import get_backend
from huggingface_hub import hf_hub_download
import spaces
# synet backend built on ultralytics; it is used below to patch in and load
# the TFLite person models.
backend = get_backend('ultralytics')

# Keep each model path under its task's directory (classify/detect/pose/segment):
# the backend relies on this layout to map a model to its task.
model_paths = {
    "Classification": "model/classify/person_classification_flash(448x640).tflite",
    "Detection": "model/detect/person_detection_flash(480x640).tflite",
    "Pose": "model/pose/person_pose_detection_flash(480x640).tflite",
    "Segmentation": "model/segment/person_segmentation_flash(480x640).tflite",
}

# Patch the backend with the default (Classification) model at import time.
# NOTE(review): what patch() configures is defined by synet — see its docs.
selected_model_name = "Classification"
backend.patch(model_paths[selected_model_name])

# Task key -> human-readable label shown in the model dropdown.
MODEL_TYPE_LABELS = {
    "Classification": "Person Classification",
    "Detection": "Person Detection",
    "Pose": "Person Pose Detection",
    "Segmentation": "Person Segmentation",
}

# Dropdown label -> task key (inverse of MODEL_TYPE_LABELS).
LABEL_TO_MODEL_TYPE = dict(zip(MODEL_TYPE_LABELS.values(), MODEL_TYPE_LABELS.keys()))
def get_img_size(model_type):
    """Return the ``imgsz`` tuple to pass to the model for ``model_type``.

    The Classification model uses (448, 640); every other task (and any
    unrecognised key) uses (480, 640). These match the sizes encoded in the
    model file names.
    """
    if model_type == "Classification":
        return (448, 640)
    return (480, 640)
def load_example_image(name):
    """Load the example image ``data/<name>.png`` and return it as a PIL image.

    ``Image.open`` is lazy and keeps the file handle open until the image is
    garbage-collected. The file is opened in a ``with`` block and an
    in-memory copy is returned, so the handle is closed immediately.

    Raises:
        FileNotFoundError: if ``data/<name>.png`` does not exist.
    """
    path = os.path.join("data", f"{name}.png")
    with Image.open(path) as img:
        # copy() forces the pixel data to load and detaches it from the file.
        return img.copy()
@spaces.GPU
def gradio_process_image(image, model_label):
    """Run the model chosen by ``model_label`` on ``image`` and return the annotated picture.

    Args:
        image: Input image. Either a NumPy array as delivered by the
            ``gr.Image(type="numpy")`` component (RGB channel order) or a PIL
            image (e.g. the startup default). ``None`` returns ``None``.
        model_label: A display label from ``MODEL_TYPE_LABELS``. Unknown labels
            fall back to the Classification model.

    Returns:
        An RGB NumPy array with the predictions drawn on it, or ``None`` when
        no image was supplied.
    """
    if image is None:
        # Nothing to run on (e.g. the image component was cleared).
        return None

    model_name = LABEL_TO_MODEL_TYPE.get(model_label, "Classification")
    print("Model selected:", model_name)
    tflite_path = model_paths[model_name]
    model = backend.get_model(tflite_path, full=True, cache=True)
    imgsz = get_img_size(model_name)
    print("Image size for model:", imgsz)

    # Ultralytics assumes raw NumPy inputs are BGR (OpenCV convention) but PIL
    # inputs are RGB. Gradio hands us RGB arrays, so wrap them in PIL to give
    # the model correctly ordered channels; NumPy and PIL callers then behave
    # the same.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    result = model(image, imgsz=imgsz)[0]
    # Replace the result's class names with readable labels used when plotting.
    if model_name == "Classification":
        result.names = ["No Person", "Person"]
    else:
        result.names = ["Person"]

    # Results.plot() returns a BGR array; flip to RGB for display in Gradio.
    return np.ascontiguousarray(result.plot()[..., ::-1])
# Stylesheet passed to gr.Blocks.
# - Colours: Synaptics blue accent and primary colours; white page background.
# - Hidden chrome: Gradio's footer and branding. NOTE: ``.svelte-*`` class names
#   are build hashes and may change between Gradio versions.
# - Components: card and button styling.
# - Dark theme: overrides for the #main_title / #subtitle elements.
custom_css = """
:root {
    --color-accent: #007dc3;
    --color-primary-500: #007dc3;
    --color-primary-600: #007dc3;
}
body, html, .gradio-container, #root, .main, .app, .gradio-interface, .gradio-block, .gradio-app {
    background: #fff !important;
    max-width: none;
}
footer, .gradio-footer, .svelte-1ipelgc, .gradio-logo, .gradio-app__settings {
    display: none !important;
}
.main-header { text-align: center; margin: 0 !important; color: #3b82f6 !important; font-weight: 600 !important; font-size: 2.5rem; letter-spacing: -0.025em; }
.card { background: #fafafa !important; border-radius: 12px !important; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important; border: 1px solid #e5e7eb !important; margin-bottom: 1.5rem !important; transition: all 0.2s ease-in-out !important; overflow: hidden !important; }
.card-header { background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important; color: white !important; padding: 1rem 1.5rem !important; border-radius: 12px 12px 0 0 !important; font-weight: 600 !important; font-size: 1.1rem !important; }
.card-content { padding: 1.5rem !important; color: #4b5563 !important; line-height: 1.6 !important; background: #fafafa !important; }
.btn-example { background: #f1f5f9 !important; border: 1px solid #cbd5e1 !important; color: #4b5563 !important; border-radius: 6px !important; transition: all 0.2s ease !important; margin: 0.35rem !important; padding: 0.5rem 1rem !important; }
.btn-example:hover { background: #1975cf !important; border-color: #1975cf !important; color: white !important; }
.btn-primary { background: #1975cf !important; border-color: #1975cf !important; color: white !important; }
.btn-primary:hover { background: #1557b0 !important; border-color: #1557b0 !important; }
.results-text strong { color: #0056b3 !important; font-weight: bold !important; }
div[data-testid="markdown"] strong { color: #0056b3 !important; font-weight: bold !important; }
/* Dark theme compatibility */
.dark h1#main_title, .dark #main_title h1 {
    color: #4dabf7 !important;
    text-shadow: none;
}
.dark h3#subtitle, .dark #subtitle h3, .dark [data-testid="markdown"] #subtitle h3 {
    color: #007dc3 !important;
    text-shadow: none;
    border: none !important;
}
.dark #subtitle a {
    color: #4dabf7 !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
    # Page title and subtitle; custom_css restyles #main_title / #subtitle in dark mode.
    gr.Markdown("<h1 style='font-size:2.5em; color:#007dc3; margin-bottom:0; text-shadow: none;'>SR100 Vision Model Space</h1>", elem_id="main_title")
    gr.Markdown("<h3 style='margin-top:0; color:#007dc3; text-shadow: none; font-weight: 500; border: none;'>Vision models for Person Presence developed by Synaptics specifically for the Astra SR100 MCU. Learn more at <a href='https://developer.synaptics.com/docs/sr/sr100/quick-start?utm_source=hf' target='_blank' style='color:#007dc3; text-decoration:underline;'>Synaptics AI Developer Zone</a></h3>", elem_id="subtitle")
    user_text = gr.Markdown("")

    with gr.Row():
        # Left column: model picker, input image, run button and example images.
        with gr.Column(scale=1):
            model_type_dropdown = gr.Dropdown(
                choices=list(MODEL_TYPE_LABELS.values()),
                value=MODEL_TYPE_LABELS["Classification"],
                label="Select Model Type",
                interactive=True,
            )
            default_image = load_example_image("Group")
            input_image = gr.Image(
                label="",
                height=250,
                value=default_image,
                interactive=False,
                sources=None,
                show_download_button=False,
                type="numpy",
            )
            classify_btn = gr.Button("Run Model", variant="primary", size="lg", elem_classes=["btn-primary"])
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Example Images</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    with gr.Row():
                        example_person = gr.Button("Person", size="sm", elem_classes=["btn-example"])
                        example_group = gr.Button("Group", size="sm", elem_classes=["btn-example"])
                    with gr.Row():
                        example_empty = gr.Button("Room", size="sm", elem_classes=["btn-example"])
                        example_person2 = gr.Button("Person2", size="sm", elem_classes=["btn-example"])

        # Right column: placeholder text, replaced by the annotated image once a model has run.
        with gr.Column(scale=1):
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Results</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    output_markdown = gr.Markdown(
                        value="<span style='color: black;'>Select an image to see predictions.</span>",
                        label="",
                        elem_classes=["results-text"],
                        visible=True,
                    )
                    output_image = gr.Image(
                        label="Output Image",
                        visible=False,
                        height=520,
                        width=800,
                    )

    def process_and_update(image, model_label):
        """Run the selected model on ``image`` and show the result.

        Returns a pair of updates for ``(output_markdown, output_image)``: the
        placeholder text is hidden and the output image is shown with the
        annotated result. With no image, the placeholder stays visible and the
        output image is hidden.
        """
        if image is None:
            return gr.update(visible=True), gr.update(visible=False)
        result_img = gradio_process_image(image, model_label)
        # Set the value and visibility in one update rather than listing
        # output_image twice in the event's outputs.
        return gr.update(visible=False), gr.update(value=result_img, visible=True)

    classify_btn.click(
        fn=process_and_update,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )
    # Example buttons replace the input image; input_image.change below then
    # re-runs the selected model on it.
    example_person.click(lambda: load_example_image("Person"), outputs=input_image)
    example_group.click(lambda: load_example_image("Group"), outputs=input_image)
    example_empty.click(lambda: load_example_image("Room"), outputs=input_image)
    example_person2.click(lambda: load_example_image("Person2"), outputs=input_image)
    input_image.change(
        fn=process_and_update,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )

    # Footer
    gr.HTML("""
<div style="max-width: 900px; margin: 2rem auto; background: white; color: black; border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); border: 1px solid #e5e7eb; padding: 1.5rem; text-align: center;">
These are all TFLite models developed by Synaptics which can be compiled for Synaptics Astra SR100 MCU.<br>
For a detailed walkthrough, please see our
<a href="https://developer.synaptics.com/docs/sr/sr100/evaluate-sr?utm_source=hf" target="_blank" style="color: #1a0dab;">Evaluate Model Guide</a>.<br>
This Space uses a simulation toolchain to estimate model performance, providing results that closely reflect real hardware behavior.
<br><br>
Request a
<a href="https://synacsm.atlassian.net/servicedesk/customer/portal/543/group/597/create/7208?utm_source=hf" target="_blank" style="color: #1a0dab;">Machina Micro [MCU] Dev Kit</a> with Astra SR100 MCU.
</div>
""")

    def on_load(image, model_label):
        """Run the selected model on the initial input image when the page loads.

        Both outputs are updated, so the initially hidden output image is
        revealed together with the startup prediction.
        """
        return process_and_update(image, model_label)

    demo.load(
        fn=on_load,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()