dgarrett-synaptics commited on
Commit
5dc1809
·
verified ·
1 Parent(s): dd6740a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import glob
2
  import gradio as gr
 
 
3
  import sr100_model_compiler
4
  from huggingface_hub import HfApi
5
  from huggingface_hub import whoami
@@ -31,19 +33,21 @@ def compile_model(model_name, sram_size, tensor_size, optimize, model_loc, clock
31
  if oauth_info['token'] is None:
32
  return "ERROR - please log into HuggingFace to continue"
33
 
34
- # Run the comparison
35
- out_dir = './tmp'
36
-
37
- # Run the model fitter
38
- results = sr100_model_compiler.sr100_model_compiler(
39
- model_file=model_name,
40
- output_dir=f"{out_dir}",
41
- model_loc=model_loc,
42
- optimize=optimize,
43
- arena_cache_size=int(float(tensor_size)*1.0e6)
44
- )
45
- print(results)
46
-
 
 
47
  default_config = sr100_model_compiler.sr100_default_config()
48
 
49
  default_config['sram_size'] = int(float(sram_size)*1.0e6)
@@ -67,7 +71,7 @@ def compile_model(model_name, sram_size, tensor_size, optimize, model_loc, clock
67
  # print(f'ERROR converting {epochs_text}, and {batch_text}')
68
  return output_text
69
 
70
-
71
  model_choices = glob.glob('models/*.tflite')
72
 
73
  def update_sliders(sram_slider_value, tensor_slider_value):
 
1
  import glob
2
  import gradio as gr
3
+ import tempfile
4
+ import os
5
  import sr100_model_compiler
6
  from huggingface_hub import HfApi
7
  from huggingface_hub import whoami
 
33
  if oauth_info['token'] is None:
34
  return "ERROR - please log into HuggingFace to continue"
35
 
36
+ # Create a temporary directory
37
+ with tempfile.TemporaryDirectory() as out_dir:
38
+ print(f"Created temporary directory: {out_dir}")
39
+
40
+ # Run the model fitter
41
+ results = sr100_model_compiler.sr100_model_compiler(
42
+ model_file=model_name,
43
+ output_dir=f"{out_dir}",
44
+ model_loc=model_loc,
45
+ optimize=optimize,
46
+ arena_cache_size=int(float(tensor_size)*1.0e6)
47
+ )
48
+ print(results)
49
+
50
+ # Analyze the model
51
  default_config = sr100_model_compiler.sr100_default_config()
52
 
53
  default_config['sram_size'] = int(float(sram_size)*1.0e6)
 
71
  # print(f'ERROR converting {epochs_text}, and {batch_text}')
72
  return output_text
73
 
74
+ # Get all available models
75
  model_choices = glob.glob('models/*.tflite')
76
 
77
  def update_sliders(sram_slider_value, tensor_slider_value):