|
|
from huggingface_hub import hf_hub_download |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import ai_edge_litert as litert |
|
|
|
|
|
|
|
|
# Hugging Face repository hosting the Perch v2 TFLite export and its label list.
repo_id = "sammlapp/perch2-tflite"

# Fetch the TFLite model file; hf_hub_download returns its local path.
model_path = hf_hub_download(repo_id=repo_id, filename="Perch2.tflite")

# Class labels: each non-blank line of the file is one label (no header row).
# The file is read as plain text rather than parsed as CSV, for two reasons:
# labels containing commas or quotes stay intact, and strings such as
# "NA"/"None"/"nan" are not coerced to NaN by pandas' default NA handling.
# Blank lines are skipped, matching pandas' default skip_blank_lines=True.
# "utf-8-sig" drops a leading BOM if present, as the pandas reader did.
# NOTE(review): presumably labels[i] names model output column i — confirm.
labels_path = hf_hub_download(repo_id=repo_id, filename="perch2_class_labels.txt")
with open(labels_path, encoding="utf-8-sig") as labels_file:
    labels = [line for line in labels_file.read().splitlines() if line.strip()]

print("Model:", model_path)
print("Labels:", len(labels))
|
|
|
|
|
|
|
|
# Load the TFLite model and allocate its tensors before querying signatures.
interpreter = litert.interpreter.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Runner for the "serving_default" signature; it is invoked later with keyword
# arguments named after the signature's inputs.
sig = interpreter.get_signature_runner("serving_default")

# Subgraph-level input details; kept so existing references keep working.
input_details = interpreter.get_input_details()

signature_list = interpreter.get_signature_list()
input_name = signature_list["serving_default"]["inputs"][0]
output_names = signature_list["serving_default"]["outputs"]

# Resolve shape/dtype for the signature input by name. The original code used
# position 0 of the subgraph input list, which is not guaranteed to be the
# same tensor as the signature's first input name.
signature_input = sig.get_input_details()[input_name]
input_shape = signature_input["shape"]
input_dtype = signature_input["dtype"]
|
|
|
|
|
# Batch of 10 random waveforms with values in [-1, 1). Each waveform is as long
# as the model's input window (last dimension of the signature input shape)
# rather than a hard-coded 160000 samples. The seeded generator makes repeated
# runs feed identical data.
# NOTE(review): a batch of 10 assumes the signature accepts a batch size other
# than the one in input_shape — confirm against the model's shape_signature.
num_samples = int(input_shape[-1])
rng = np.random.default_rng(0)
sample_input = rng.uniform(-1, 1, size=(10, num_samples)).astype(input_dtype)

# Run the signature. The result is dict-like; presumably it is keyed by the
# names listed in output_names — verify.
result = sig(**{input_name: sample_input})

print("Output keys:", result.keys())
for k, v in result.items():
    print(k, v.shape)