Aditya Sahu commited on
Commit
ef15281
·
verified ·
1 Parent(s): e8828b9

Add validation for tflite models uploads

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -4,11 +4,47 @@ import tempfile
4
  import os
5
  import sr100_model_compiler
6
  import html
 
7
 
8
- def compile_model(model_name, vmem_value, lpmem_value):
9
 
10
- #if oauth_info['token'] is None:
11
- # return "ERROR - please log into HuggingFace to continue"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Create a temporary directory
14
  with tempfile.TemporaryDirectory() as out_dir:
@@ -19,13 +55,12 @@ def compile_model(model_name, vmem_value, lpmem_value):
19
 
20
  # Run the model fitter
21
  success, results = sr100_model_compiler.sr100_model_optimizer(
22
- model_file=model_name,
23
  vmem_size_limit=vmem_size_limit,
24
  lpmem_size_limit=lpmem_size_limit
25
  )
26
  print(results)
27
 
28
- # Format results in nicely styled HTML like in old.py
29
  output = []
30
 
31
  if results['cycles_npu'] == 0:
@@ -97,7 +132,6 @@ def compile_model(model_name, vmem_value, lpmem_value):
97
  # Get all available models
98
  model_choices = glob.glob('models/*.tflite')
99
 
100
- # Custom CSS from old.py
101
  custom_css = """
102
  :root {
103
  --color-accent: #007dc3;
@@ -140,32 +174,31 @@ footer, .gradio-footer, .svelte-1ipelgc, .gradio-logo, .gradio-app__settings {
140
  """
141
 
142
  with gr.Blocks(css=custom_css) as demo:
143
- #gr.LoginButton()
144
  gr.Markdown("<h1 style='font-size:2.5em; color:#007dc3; margin-bottom:0;'>SR100 Model Compiler</h1>", elem_id="main_title")
145
  gr.Markdown("<h3 style='margin-top:0; color:#000;'>Bring a TFlite INT8 model and compile it for Synaptics Astra SR100. Learn more at <a href='https://developer.synaptics.com/docs/sr/sr100/quick-start?utm_source=hf' target='_blank' style='color:#007dc3; text-decoration:underline;'>Synaptics AI Developer Zone</a></h3>", elem_id="subtitle")
146
- #user_text = gr.Markdown("")
147
 
148
  # Setup model inputs
149
  with gr.Row():
150
  vmem_slider = gr.Slider(minimum=0, maximum=1536, step=1.024, label="Set total VMEM SRAM size available in kB", value=1536.0)
151
  lpmem_slider = gr.Slider(minimum=0, maximum=1536, step=1.024, label="Set total LPMEM SRAM size in kB", value=1536.0)
152
 
153
- # Setup model compile
154
  model_dropdown = gr.Dropdown(
155
- label="Select an model",
156
  value='models/hello_world.tflite',
157
  choices=model_choices
158
  )
159
 
 
 
 
160
  # Run the compile
161
  compile_btn = gr.Button("Compile Model")
162
  compile_text = gr.Markdown("<span style='color:#000;'>Waiting for model results</span>")
163
 
164
  # Compute options
165
- compile_btn.click(compile_model, inputs=[model_dropdown, vmem_slider, lpmem_slider], outputs=[compile_text])
166
- #demo.load(get_oauth_info, inputs=None, outputs=user_text)
167
 
168
- # Add footer content from old.py
169
  gr.HTML("""
170
  <div style="max-width: 800px; margin: 2rem auto; background: white; color: black; border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); border: 1px solid #e5e7eb; padding: 1.5rem; text-align: center;">
171
  For a detailed walkthrough, please see our
 
4
  import os
5
  import sr100_model_compiler
6
  import html
7
+ import pathlib
8
 
9
+ # ---------- Helpers ----------
10
 
11
+ def _resolve_uploaded_path(uploaded):
12
+ """
13
+ Normalize Gradio File input into a filesystem path.
14
+ Handles: str, dict with {path|name}, file-like objects with .path/.name,
15
+ or a list/tuple of the above.
16
+ """
17
+ if uploaded is None:
18
+ return None
19
+ if isinstance(uploaded, (list, tuple)) and uploaded:
20
+ return _resolve_uploaded_path(uploaded[0])
21
+ if isinstance(uploaded, str):
22
+ return uploaded
23
+ if isinstance(uploaded, dict):
24
+ return uploaded.get("path") or uploaded.get("name")
25
+ for attr in ("path", "name"):
26
+ if hasattr(uploaded, attr):
27
+ return getattr(uploaded, attr)
28
+ return None
29
+
30
+ def compile_model(model_name, vmem_value, lpmem_value, uploaded_model):
31
+ # Decide the source model path (uploaded has priority)
32
+ uploaded_path = _resolve_uploaded_path(uploaded_model)
33
+ model_path = uploaded_path or model_name
34
+
35
+ # Basic validations
36
+ if not model_path or not os.path.exists(model_path):
37
+ return (
38
+ "<div style='color:#d32f2f; font-weight:bold; font-size:1.1em;'>"
39
+ "❌ ERROR: Could not locate the model file you selected or uploaded."
40
+ "</div>"
41
+ )
42
+
43
+ if pathlib.Path(model_path).suffix.lower() != ".tflite":
44
+ return (
45
+ "<div style='color:#d32f2f; font-weight:bold; font-size:1.1em;'>"
46
+ "❌ ERROR: Please provide a <code>.tflite</code> model file.</div>"
47
+ )
48
 
49
  # Create a temporary directory
50
  with tempfile.TemporaryDirectory() as out_dir:
 
55
 
56
  # Run the model fitter
57
  success, results = sr100_model_compiler.sr100_model_optimizer(
58
+ model_file=model_path,
59
  vmem_size_limit=vmem_size_limit,
60
  lpmem_size_limit=lpmem_size_limit
61
  )
62
  print(results)
63
 
 
64
  output = []
65
 
66
  if results['cycles_npu'] == 0:
 
132
  # Get all available models
133
  model_choices = glob.glob('models/*.tflite')
134
 
 
135
  custom_css = """
136
  :root {
137
  --color-accent: #007dc3;
 
174
  """
175
 
176
  with gr.Blocks(css=custom_css) as demo:
 
177
  gr.Markdown("<h1 style='font-size:2.5em; color:#007dc3; margin-bottom:0;'>SR100 Model Compiler</h1>", elem_id="main_title")
178
  gr.Markdown("<h3 style='margin-top:0; color:#000;'>Bring a TFlite INT8 model and compile it for Synaptics Astra SR100. Learn more at <a href='https://developer.synaptics.com/docs/sr/sr100/quick-start?utm_source=hf' target='_blank' style='color:#007dc3; text-decoration:underline;'>Synaptics AI Developer Zone</a></h3>", elem_id="subtitle")
 
179
 
180
  # Setup model inputs
181
  with gr.Row():
182
  vmem_slider = gr.Slider(minimum=0, maximum=1536, step=1.024, label="Set total VMEM SRAM size available in kB", value=1536.0)
183
  lpmem_slider = gr.Slider(minimum=0, maximum=1536, step=1.024, label="Set total LPMEM SRAM size in kB", value=1536.0)
184
 
185
+ # Setup model selection/upload
186
  model_dropdown = gr.Dropdown(
187
+ label="Select a model",
188
  value='models/hello_world.tflite',
189
  choices=model_choices
190
  )
191
 
192
+ # Add file upload component
193
+ model_upload = gr.File(label="Or upload a .tflite INT8 model", file_types=[".tflite"], file_count="single")
194
+
195
  # Run the compile
196
  compile_btn = gr.Button("Compile Model")
197
  compile_text = gr.Markdown("<span style='color:#000;'>Waiting for model results</span>")
198
 
199
  # Compute options
200
+ compile_btn.click(compile_model, inputs=[model_dropdown, vmem_slider, lpmem_slider, model_upload], outputs=[compile_text])
 
201
 
 
202
  gr.HTML("""
203
  <div style="max-width: 800px; margin: 2rem auto; background: white; color: black; border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); border: 1px solid #e5e7eb; padding: 1.5rem; text-align: center;">
204
  For a detailed walkthrough, please see our