File size: 9,383 Bytes
6785c36
 
 
 
 
 
c639d12
6785c36
 
 
5b8727d
1d2f4d4
 
 
 
 
 
6785c36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c639d12
6785c36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151dbcd
6785c36
 
 
 
 
 
 
 
 
3074e79
 
e8e77d9
3074e79
5b1e399
3074e79
e8e77d9
9a364e1
 
e8e77d9
3074e79
e8e77d9
3074e79
1ed4f67
6785c36
 
 
e8e77d9
9a364e1
6785c36
 
 
 
 
 
 
 
 
1ed4f67
6785c36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb49fb2
6785c36
 
 
a2004e0
 
 
6785c36
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import gradio as gr
import numpy as np
from PIL import Image
import os
from synet.backends import get_backend
from huggingface_hub import hf_hub_download
import spaces

backend = get_backend("ultralytics")

# Keep every model file under a directory named for its task (classify,
# detect, pose, segment): the task in the path helps the backend map each
# model to its task.
model_paths = dict(
    Classification="model/classify/person_classification_flash(448x640).tflite",
    Detection="model/detect/person_detection_flash(480x640).tflite",
    Pose="model/pose/person_pose_detection_flash(480x640).tflite",
    Segmentation="model/segment/person_segmentation_flash(480x640).tflite",
)

# Patch the backend with the initially selected (classification) model path.
selected_model_name = "Classification"
backend.patch(model_paths[selected_model_name])

# Human-readable label shown in the UI dropdown for each model type
# (insertion order is the order the choices appear in).
MODEL_TYPE_LABELS = dict(
    Classification="Person Classification",
    Detection="Person Detection",
    Pose="Person Pose Detection",
    Segmentation="Person Segmentation",
)

# Inverse of MODEL_TYPE_LABELS: UI label -> model type key.
LABEL_TO_MODEL_TYPE = dict(zip(MODEL_TYPE_LABELS.values(), MODEL_TYPE_LABELS))

def get_img_size(model_type):
    """Return the ``imgsz`` tuple to run *model_type* at.

    The classification model uses ``(448, 640)``; every other model type
    (and any unrecognised key) uses ``(480, 640)``. These match the
    resolutions encoded in the model file names.
    """
    if model_type != "Classification":
        return (480, 640)
    return (448, 640)

def load_example_image(name):
    """Load ``data/<name>.png`` and return it as a fully loaded PIL image.

    ``Image.open`` is lazy and keeps the file handle open until the pixel
    data is read. Opening it in a ``with`` block and returning a ``copy()``
    (which forces a load) releases the file immediately. The returned image
    no longer depends on it.
    """
    path = os.path.join("data", f"{name}.png")
    with Image.open(path) as img:
        return img.copy()

@spaces.GPU
def gradio_process_image(image, model_label):
    """Run the model selected by *model_label* on *image* and return the plot.

    Args:
        image: Input image, either a numpy array (as produced by the
            ``gr.Image`` component with ``type="numpy"``) or a PIL image.
            ``None`` yields ``None`` without running the model.
        model_label: Display label from ``MODEL_TYPE_LABELS``; unknown labels
            fall back to the classification model.

    Returns:
        The annotated prediction as an RGB numpy array, or ``None`` when no
        image was supplied.
    """
    if image is None:
        # Nothing to run on. Passing ``None`` through could make the backend
        # fall back to some default source instead of failing visibly.
        return None

    model_name = LABEL_TO_MODEL_TYPE.get(model_label, "Classification")
    print("Model selected:", model_name)
    tflite_path = model_paths[model_name]
    model = backend.get_model(tflite_path, full=True, cache=True)
    imgsz = get_img_size(model_name)
    print("Image size for model:", imgsz)

    # Ultralytics interprets numpy input as BGR but PIL input as RGB. Gradio
    # hands over RGB numpy arrays, so wrap them in a PIL image. The model then
    # sees the correct channel order, the same as for the PIL default image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    result = model(image, imgsz=imgsz)[0]
    if model_name == "Classification":
        result.names = ["No Person", "Person"]
    else:
        result.names = ["Person"]

    # Results.plot() returns a BGR array; flip to RGB for display in Gradio.
    return np.ascontiguousarray(result.plot()[..., ::-1])

# Theme overrides for the Gradio app: brand colours, white background,
# hidden Gradio footer/logo elements, card and button styling, and title
# colours for the dark theme.
custom_css = """
:root {
    --color-accent: #007dc3;
    --color-primary-500: #007dc3;
    --color-primary-600: #007dc3;
}
body, html, .gradio-container, #root, .main, .app, .gradio-interface, .gradio-block, .gradio-app {
    background: #fff !important;
    max-width: none;
}
footer, .gradio-footer, .svelte-1ipelgc, .gradio-logo, .gradio-app__settings {
    display: none !important;
}
.main-header { text-align: center; margin: 0 !important; color: #3b82f6 !important; font-weight: 600 !important; font-size: 2.5rem; letter-spacing: -0.025em; }
.card { background: #fafafa !important; border-radius: 12px !important; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important; border: 1px solid #e5e7eb !important; margin-bottom: 1.5rem !important; transition: all 0.2s ease-in-out !important; overflow: hidden !important; }
.card-header { background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important; color: white !important; padding: 1rem 1.5rem !important; border-radius: 12px 12px 0 0 !important; font-weight: 600 !important; font-size: 1.1rem !important; }
.card-content { padding: 1.5rem !important; color: #4b5563 !important; line-height: 1.6 !important; background: #fafafa !important; }
.btn-example { background: #f1f5f9 !important; border: 1px solid #cbd5e1 !important; color: #4b5563 !important; border-radius: 6px !important; transition: all 0.2s ease !important; margin: 0.35rem !important; padding: 0.5rem 1rem !important; }
.btn-example:hover { background: #1975cf !important; border-color: #1975cf !important; color: white !important; }
.btn-primary { background: #1975cf !important; border-color: #1975cf !important; color: white !important; }
.btn-primary:hover { background: #1557b0 !important; border-color: #1557b0 !important; }
.results-text strong { color: #0056b3 !important; font-weight: bold !important; }
div[data-testid="markdown"] strong { color: #0056b3 !important; font-weight: bold !important; }

/* Dark theme compatibility */
.dark h1#main_title, .dark #main_title h1 {
    color: #4dabf7 !important;
    text-shadow: none;
}
.dark h3#subtitle, .dark #subtitle h3, .dark [data-testid="markdown"] #subtitle h3 {
    color: #007dc3 !important;
    text-shadow: none;
    border: none !important;
}
.dark #subtitle a {
    color: #4dabf7 !important;
}
"""

# UI layout: model picker, input image and example buttons on the left;
# prediction results on the right; informational footer below.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("<h1 style='font-size:2.5em; color:#007dc3; margin-bottom:0; text-shadow: none;'>SR100 Vision Model Space</h1>", elem_id="main_title")
    gr.Markdown("<h3 style='margin-top:0; color:#007dc3; text-shadow: none; font-weight: 500; border: none;'>Vision models for Person Presence developed by Synaptics specifically for the Astra SR100 MCU. Learn more at <a href='https://developer.synaptics.com/docs/sr/sr100/quick-start?utm_source=hf' target='_blank' style='color:#007dc3; text-decoration:underline;'>Synaptics AI Developer Zone</a></h3>", elem_id="subtitle")
    # Empty Markdown placeholder; no handler in this file writes to it.
    user_text = gr.Markdown("")
    with gr.Row():
        with gr.Column(scale=1):
            model_type_dropdown = gr.Dropdown(
                choices=list(MODEL_TYPE_LABELS.values()),
                value=MODEL_TYPE_LABELS["Classification"],
                label="Select Model Type",
                interactive=True,
            )
            default_image = load_example_image("Group")
            # Display-only input: images are chosen via the example buttons below.
            input_image = gr.Image(
                label="",
                height=250,
                value=default_image,
                interactive=False,
                sources=None,
                show_download_button=False,
                type="numpy",
            )

            classify_btn = gr.Button("Run Model", variant="primary", size="lg", elem_classes=["btn-primary"])
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Example Images</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    with gr.Row():
                        example_person = gr.Button("Person", size="sm", elem_classes=["btn-example"])
                        example_group = gr.Button("Group", size="sm", elem_classes=["btn-example"])
                    with gr.Row():
                        example_empty = gr.Button("Room", size="sm", elem_classes=["btn-example"])
                        example_person2 = gr.Button("Person2", size="sm", elem_classes=["btn-example"])

        with gr.Column(scale=1):
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Results</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    output_markdown = gr.Markdown(
                        value="<span style='color: black;'>Select an image to see predictions.</span>",
                        label="",
                        elem_classes=["results-text"],
                        visible=True,
                    )
                    # Hidden until the first prediction is produced.
                    output_image = gr.Image(
                        label="Output Image",
                        visible=False,
                        height=520,
                        width=800,
                    )

    def process_and_update(image, model_label):
        """Run the selected model on *image* and reveal the annotated result.

        Returns one update per output component. The first hides the
        placeholder text in ``output_markdown``; the second shows
        ``output_image`` with the prediction as its value.
        """
        result_img = gradio_process_image(image, model_label)
        return gr.update(visible=False), gr.update(visible=True, value=result_img)

    # Each output component is listed exactly once; the image's visibility
    # and value travel together in a single update.
    classify_btn.click(
        fn=process_and_update,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )

    example_person.click(lambda: load_example_image("Person"), outputs=input_image)
    example_group.click(lambda: load_example_image("Group"), outputs=input_image)
    example_empty.click(lambda: load_example_image("Room"), outputs=input_image)
    example_person2.click(lambda: load_example_image("Person2"), outputs=input_image)

    # Selecting an example replaces input_image's value, which re-runs the model.
    input_image.change(
        fn=process_and_update,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )

    # Footer
    gr.HTML("""
    <div style="max-width: 900px; margin: 2rem auto; background: white; color: black; border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); border: 1px solid #e5e7eb; padding: 1.5rem; text-align: center;">
        These are all TFLite models developed by Synaptics which can be compiled for Synaptics Astra SR100 MCU.<br>
        For a detailed walkthrough, please see our
        <a href="https://developer.synaptics.com/docs/sr/sr100/evaluate-sr?utm_source=hf" target="_blank" style="color: #1a0dab;">Evaluate Model Guide</a>.<br>
        This Space uses a simulation toolchain to estimate model performance providing results that closely reflect real hardware behavior.
        <br><br>
        Request a 
        <a href="https://synacsm.atlassian.net/servicedesk/customer/portal/543/group/597/create/7208?utm_source=hf" target="_blank" style="color: #1a0dab;">Machina Micro [MCU] Dev Kit</a> with Astra SR100 MCU.
    </div>
    """)

    # Predict on the default image (with the default dropdown model) when the
    # page loads. Going through process_and_update ensures two things:
    # - output_image, hidden initially, is actually revealed and the
    #   placeholder text is hidden;
    # - the model receives the same numpy input as the other handlers.
    demo.load(
        fn=process_and_update,
        inputs=[input_image, model_type_dropdown],
        outputs=[output_markdown, output_image],
    )

if __name__ == "__main__":
    demo.launch()