# File size: 323 Bytes
# ba96580
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
from axengine import InferenceSession
import numpy as np


# Compiled model file to load; the path is relative to the current working directory.
MODEL_PATH = "./compiled_slice_quant_onnx/cfg_00_timestep_to_model_t_embedder_mlp_mlp_2_Gemm_output_0_config.axmodel"

# Open an inference session on the compiled model.
session = InferenceSession(MODEL_PATH)

# The model is fed a single input named "timestep": a one-element float32 array.
timestep = np.array([1.0], dtype=np.float32)
input_feed = {"timestep": timestep}

# NOTE(review): passing None as the first argument presumably requests every
# model output (onnxruntime-style API) — confirm against axengine docs.
# Only the first returned output is kept.
outputs = session.run(None, input_feed)
output = outputs[0]

# Print the shape of the first output as a quick sanity check.
print(output.shape)