sammlapp commited on
Commit
ea18f66
·
verified ·
1 Parent(s): d28a72a

Create sample_inference_script.py

Browse files
Files changed (1) hide show
  1. sample_inference_script.py +39 -0
sample_inference_script.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ import pandas as pd
3
+ import numpy as np
4
+ import ai_edge_litert as litert
5
+
6
+ # Download model
7
+ repo_id = "sammlapp/perch2-tflite"
8
+ model_path = hf_hub_download(repo_id=repo_id, filename="Perch2.tflite")
9
+
10
+ # Download and read labels
11
+ labels_path = hf_hub_download(repo_id=repo_id, filename="perch2_class_labels.txt")
12
+ labels = pd.read_csv(labels_path, header=None).iloc[:, 0].tolist()
13
+
14
+ print("Model:", model_path)
15
+ print("Labels:", len(labels))
16
+
17
+ # path = "./Perch2.tflite"
18
+ interpreter = litert.interpreter.Interpreter(model_path=model_path)
19
+ interpreter.allocate_tensors()
20
+
21
+ sig = interpreter.get_signature_runner("serving_default")
22
+
23
+ input_details = interpreter.get_input_details()
24
+ signature_list = interpreter.get_signature_list()
25
+
26
+ input_name = signature_list["serving_default"]["inputs"][0]
27
+ output_names = signature_list["serving_default"]["outputs"]
28
+ input_shape = input_details[0]["shape"]
29
+ input_dtype = input_details[0]["dtype"]
30
+
31
+ sample_input = np.random.uniform(-1, 1, size=(10, 160000)).astype(input_dtype)
32
+
33
+ # Run via signature
34
+ result = sig(**{input_name: sample_input})
35
+
36
+ # Each output key in the signature dict will hold a numpy array (safe to access)
37
+ print("Output keys:", result.keys())
38
+ for k, v in result.items():
39
+ print(k, v.shape)