Upload 10 files
Browse files- .gitattributes +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn +3 -0
- convert_rknn.py +54 -0
- ea50ffd6-c6fe-11ef-8ff3-1c860b30973e +3 -0
- export_onnx.py +120 -0
- inference.py +210 -0
- rwkv_tokenizer.py +89 -0
- rwkv_vocab_v20230424.txt +0 -0
- ztu_somemodelruntime_rknnlite2.py +509 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ea50ffd6-c6fe-11ef-8ff3-1c860b30973e filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn filter=lfs diff=lfs merge=lfs -text
|
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7bb71268884738ee0bbc62796b838afd9b460da931589151d949e538cbe58255
|
| 3 |
+
size 201326592
|
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:080c9153102fe9c2c54e8245411a9ab70360132a13321c3396dd7cca17eca1c4
|
| 3 |
+
size 305312
|
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:17c375a232e19992bba49459fa7a092ecdb6252841b80850095eb5c6fb4e2bf4
|
| 3 |
+
size 289121271
|
convert_rknn.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# coding: utf-8
"""Convert the RWKV ONNX model to an RKNN model targeting the RK3588 NPU."""

import datetime
from rknn.api import RKNN
from sys import exit


ONNX_MODEL = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx"
RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
DATASET = ""
QUANTIZE = False
detailed_performance_log = True

# Embedded into the .rknn as custom_string so the artifact is traceable.
timedate_iso = datetime.datetime.now().isoformat()

rknn = RKNN(verbose=True)
rknn.config(
    # mean_values=[x * 255 for x in [0.485, 0.456, 0.406]],
    # std_values=[x * 255 for x in [0.229, 0.224, 0.225]],
    quantized_dtype="w8a8",
    quantized_algorithm="normal",
    quantized_method="channel",
    quantized_hybrid_level=0,
    target_platform="rk3588",
    quant_img_RGB2BGR=False,
    float_dtype="float16",
    optimization_level=3,
    custom_string=f"converted at {timedate_iso}",
    remove_weight=False,
    compress_weight=False,
    inputs_yuv_fmt=None,
    single_core_mode=False,
    dynamic_input=None,
    model_pruning=False,
    op_target=None,
    quantize_weight=False,
    remove_reshape=False,
    sparse_infer=False,
    enable_flash_attention=False,
    # hidden (undocumented) parameters
    # disable_rules=[],
    # sram_prefer=False,
    # nbuf_prefer=False,
    # check_data=[],
)

# Fixed: return codes were previously assigned but never checked, so a
# failed step silently produced a missing or broken .rknn file.
ret = rknn.load_onnx(model=ONNX_MODEL)
if ret != 0:
    print(f"load_onnx failed (ret={ret})")
    exit(ret)

ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
if ret != 0:
    print(f"build failed (ret={ret})")
    exit(ret)

ret = rknn.export_rknn(RKNN_MODEL)
if ret != 0:
    print(f"export_rknn failed (ret={ret})")
    exit(ret)

# ret = rknn.init_runtime(target='rk3588',device_id='cbb956772bf5dac9',core_mask=RKNN.NPU_CORE_0,perf_debug=detailed_performance_log)
# rknn.eval_perf()
# ret = rknn.accuracy_analysis(inputs=['../embeddings.npy','../state.npy','../scale_ratio.npy'], target='rk3588', device_id=device_id)
|
ea50ffd6-c6fe-11ef-8ff3-1c860b30973e
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:faa4dce148b8ed0172ef021b8f732c2eea5dd782caed801dd4727d909d2b9447
|
| 3 |
+
size 562805760
|
export_onnx.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rwkv_src.rwkv_model import RWKV_RNN, make_chunks
import types
import os
import torch
import numpy as np
import argparse
import json
import copy
from pathlib import Path
import onnx
from onnx import shape_inference

parser = argparse.ArgumentParser(description='Convert model')
parser.add_argument('model', type=Path, help='Path to RWKV pth file')
parser.add_argument('--chunks', type=int, default=1, help='Number of chunks')
parser.add_argument('--ext_embedding', action='store_true', default=False, help='Use external embedding')
parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
parser_args = parser.parse_args()

# Prefill models consume 32 tokens per step; decode models consume one.
seq_length = 32 if parser_args.prefill_model else 1

model_args = types.SimpleNamespace()
model_args.USE_CUDA = False
model_args.fp16 = False
model_args.wkv_customop = parser_args.wkv_customop
model_args.USE_EMBEDDING = not parser_args.ext_embedding  # was: False if ... else True

model_args.MODEL_NAME = str(parser_args.model)

# ABC/MIDI/x070 model families do not need layer rescaling.
if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
    model_args.RESCALE_LAYER = 0
else:
    model_args.RESCALE_LAYER = 6

model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)

if parser_args.prefill_model:
    model_args.MODEL_NAME = model_args.MODEL_NAME + "_prefill"

os.makedirs("onnx", exist_ok=True)  # fixed: race-free replacement for exists()-or-mkdir()


def _new_state(args, fp16):
    """Build the zeroed per-layer state: att shift, wkv matrix, ffn shift."""
    dtype = torch.float16 if fp16 else torch.float32
    states = []
    for _ in range(args.n_layer):
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=dtype))
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
    return states


def _append_initializer_value_infos(onnx_path):
    """Run shape inference, then expose initializer shapes as value_info
    (improves compatibility with other frameworks) and re-save the model
    with all external data in one file."""
    shape_inference.infer_shapes_path(onnx_path)
    onnx_model = onnx.load(onnx_path)
    for initializer in onnx_model.graph.initializer:
        shape = list(initializer.dims)
        value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
        onnx_model.graph.value_info.append(value_info)
    onnx.save_model(onnx_model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True)


if type(model) == list:
    # Multi-chunk export: each chunk gets its own directory and state slice.
    args = model[0].args
    model_basename = args.MODEL_NAME.split("/")[-1]
    if not args.USE_EMBEDDING:
        model[0].emb_weight.cpu().numpy().astype(np.float32).tofile(f"onnx/{model_basename}_chunk1of{len(model)}.emb")
    fp16 = args.fp16
    states = _new_state(args, fp16)
    # Fixed: `is not torch.device('cpu')` compared object identity against a
    # freshly constructed device, which is unreliable; use != (value equality).
    if model[0].device != torch.device('cpu'):
        states = [s.to(model[0].device) for s in states]

    for i in range(len(model)):
        dirname = f"onnx/{model_basename}_chunk{i+1}of{len(model)}"
        os.makedirs(dirname, exist_ok=True)
        chunk_path = f"{dirname}/{model_basename}_chunk{i+1}of{len(model)}.onnx"
        # Chunk 0 takes token ids when the embedding lives inside the model;
        # all other chunks (and external-embedding chunk 0) take vectors.
        if i == 0 and args.USE_EMBEDDING:
            in0 = torch.LongTensor([[1]*seq_length])
        else:
            in0 = torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)

        if model[0].device != torch.device('cpu'):
            in0 = in0.to(model[0].device)
        inputs = {'in0': in0, 'state': [states[j] for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]}
        input_names = ['in'] + [f'state{j}_in' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
        output_names = ['out'] + [f'state{j}_out' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]

        if args.wkv_customop:
            from torch.onnx.symbolic_helper import _get_tensor_sizes
            from torch.onnx import register_custom_op_symbolic
            op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
            def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
                """Symbolic for the custom wkv op: emit a 2-output node with
                explicit result types so ONNX shape inference can proceed."""
                out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
                return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])),\
                    out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
            register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)

        torch.onnx.export(model[i], inputs, chunk_path, input_names=input_names, output_names=output_names, opset_version=17)
        _append_initializer_value_infos(chunk_path)
        print(f"onnx model chunk{i} saved to {chunk_path}")

else:
    # Single-model export.
    args = model.args
    model_basename = args.MODEL_NAME.split("/")[-1]
    if not args.USE_EMBEDDING:
        model.emb_weight.cpu().numpy().astype(np.float32).tofile(f"onnx/{model_basename}.emb")
    fp16 = args.fp16
    in0 = torch.LongTensor([[1]*seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
    states = _new_state(model.args, fp16)
    if model.device != torch.device('cpu'):
        states = [tensor.to(model.device) for tensor in states]
    inputs = {'in0': in0, 'state': states}
    input_names = ['in'] + [f'state{i}_in' for i in range(3*model.args.n_layer)]
    output_names = ['logits'] + [f'state{i}_out' for i in range(3*model.args.n_layer)]
    model_path = f"onnx/{model_basename}.onnx"
    torch.onnx.export(model, inputs, model_path, input_names=input_names, output_names=output_names, opset_version=17)
    _append_initializer_value_infos(model_path)
    print(f"onnx model saved to {model_path}")
|
inference.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import onnxruntime as ort # Uncomment this line to use onnxruntime
|
| 2 |
+
import ztu_somemodelruntime_rknnlite2 as ort # Uncomment this line to use rknnlite2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from rwkv_tokenizer import RWKV_TOKENIZER
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
class RWKVModel:
    """Runtime wrapper around an exported RWKV model (ONNX or RKNN).

    Keeps the recurrent state between calls so tokens can be fed one at a
    time, and can optionally look token embeddings up from an external
    ``.emb`` file sitting next to the model file.
    """

    def __init__(self, model_path: str, tokenizer_path: str = None,
                 use_external_embedding: bool = False, embedding_dim: int = 768):
        """Load the model (and optionally tokenizer / external embedding).

        Args:
            model_path: path to the exported model file.
            tokenizer_path: optional path to the RWKV vocab file.
            use_external_embedding: read embeddings from a sibling ``.emb``
                file and feed vectors instead of token ids.
            embedding_dim: width of each external embedding vector
                (was hard-coded to 768; kept as the default for
                backward compatibility).
        """
        session_options = ort.SessionOptions()
        # session_options.core_mask = 7  # 00000111 -> use NPU cores 0,1,2
        self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'], session_options=session_options)

        # Debug: show what the model expects.
        print("\nModel inputs:")
        for inp in self.session.get_inputs():
            print(f"{inp.name}: shape={inp.shape}, type={inp.type}")

        # RWKV keeps three state tensors per layer.
        self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
        self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None

        # Remember each state tensor's shape so reset_state() can rebuild them.
        self.state_shapes = {}
        for inp in self.session.get_inputs():
            if 'state' in inp.name:
                self.state_shapes[inp.name] = inp.shape

        print("\nNumber of layers:", self.n_layer)

        # Optional tokenizer (required only for generate()).
        if tokenizer_path:
            self.tokenizer = RWKV_TOKENIZER(tokenizer_path)
        else:
            self.tokenizer = None

        # Optional external embedding table: raw float32 file next to model.
        self.use_external_embedding = use_external_embedding
        if use_external_embedding:
            emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
            self.embedding = np.fromfile(emb_path, dtype=np.float32)
            vocab_size = len(self.embedding) // embedding_dim
            self.embedding = self.embedding.reshape(vocab_size, embedding_dim)
            self.n_embd = embedding_dim
            print(f"\nEmbedding shape: {self.embedding.shape}")

        # Fixed: one-time debug prints used to be keyed on `token_id == 0`,
        # which re-fired every time token 0 was fed; use a flag instead.
        self._debug_printed = False

        self.reset_state()

    def reset_state(self):
        """Reset all recurrent state tensors to zeros."""
        self.states = []
        for i in range(self.n_layer * 3):
            state_name = f'state{i}_in'
            state_shape = self.state_shapes[state_name]
            self.states.append(np.zeros(state_shape, dtype=np.float32))

    def _prepare_inputs(self, token_id):
        """Build the feed dict for one inference step."""
        inputs = {}

        if self.use_external_embedding:
            # Feed the embedding vector for this token.
            embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
            inputs['in'] = embedding.astype(np.float32)
        else:
            # Feed the raw token id.
            inputs['in'] = np.array([[token_id]], dtype=np.int64)

        # Attach the recurrent state.
        for i in range(len(self.states)):
            inputs[f'state{i}_in'] = self.states[i]

        # One-time debug output (defaults to "already printed" when the
        # instance was built without __init__).
        if not getattr(self, '_debug_printed', True):
            print("\nPrepared input shapes:")
            for k, v in inputs.items():
                print(f"{k}: shape={v.shape}, type={v.dtype}")

        return inputs

    def forward(self, token_id):
        """Run one step; return the logits and update the recurrent state."""
        inputs = self._prepare_inputs(token_id)
        outputs = self.session.run(None, inputs)

        first_call = not getattr(self, '_debug_printed', True)
        if first_call:
            print("\nModel outputs:")
            for i, out in enumerate(outputs):
                print(f"Output {i}: shape={out.shape}, type={out.dtype}")

        # Update state; output 0 is the logits, the rest are states in order.
        for i in range(len(self.states)):
            new_state = outputs[i + 1]
            if new_state.shape != self.states[i].shape:
                if first_call:
                    print(f"\nState shape mismatch for state{i}_in:")
                    print(f"Expected: {self.states[i].shape}")
                    print(f"Got: {new_state.shape}")
                # Squeeze away the extra dim some runtimes add.
                if len(self.states[i].shape) == 2:
                    new_state = new_state.squeeze(1)  # (1, 1, n) -> (1, n)
                elif len(self.states[i].shape) == 3:
                    new_state = new_state.squeeze(0)  # (1, h, s, s) -> (h, s, s)
            self.states[i] = new_state

        self._debug_printed = True
        return outputs[0]

    def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None):
        """Generate text from a prompt.

        Args:
            prompt: seed text fed through the model before sampling.
            max_length: maximum number of new tokens to sample.
            temperature: >0 samples from the softmax; <=0 is greedy argmax.
            stop_tokens: token ids that terminate generation.

        Returns:
            The decoded full sequence (prompt tokens + generated tokens).

        Raises:
            ValueError: if no tokenizer was supplied at construction time.
        """
        if not self.tokenizer:
            raise ValueError("需要提供tokenizer才能进行文本生成")

        tokens = self.tokenizer.encode(prompt)
        generated = list(tokens)

        self.reset_state()

        # Feed the prompt through the model to build up state.
        print("\nProcessing prompt...", end='', flush=True)
        t_start = time.time()
        for token in tokens:
            logits = self.forward(token)
        t_prompt = time.time() - t_start
        print(f" Done. ({len(tokens)} tokens, {t_prompt:.2f}s, {len(tokens)/t_prompt:.2f} tokens/s)")

        # Sample new tokens one at a time.
        print("\nGenerating:", end='', flush=True)
        t_start = time.time()
        generated_tokens = 0

        for i in range(max_length):
            logits = self.forward(generated[-1])

            if i == 0:
                print(f"\nLogits shape: {logits.shape}")

            logits = logits.reshape(-1)  # flatten to 1-D

            if temperature > 0:
                # Softmax with max-subtraction for numerical stability.
                logits = logits / temperature
                logits = logits - np.max(logits)
                probs = np.exp(logits)
                probs = probs / np.sum(probs)
                next_token = np.random.choice(len(probs), p=probs)
            else:
                next_token = np.argmax(logits)

            generated.append(next_token)
            generated_tokens += 1

            if stop_tokens and next_token in stop_tokens:
                break

            # Stream the new token to stdout as it is produced.
            new_text = self.tokenizer.decode([next_token])
            print(new_text, end='', flush=True)

        t_generate = time.time() - t_start
        print(f"\n\nGeneration finished: {generated_tokens} tokens generated in {t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)")

        return self.tokenizer.decode(generated)
|
| 182 |
+
|
| 183 |
+
def main():
    """Load the RWKV model and run a sample generation."""
    # Fixed: removed a redundant function-local `import time` that shadowed
    # the module-level import.
    print("Loading model...")
    t_start = time.time()
    model = RWKVModel(
        model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
        tokenizer_path='rwkv_vocab_v20230424.txt',
        use_external_embedding=True
    )
    print(f"Model loaded in {time.time() - t_start:.2f}s")

    prompt = "Here is a example of Quick Sort algorithm implemented in C++:\n```cpp"
    print(f"\nPrompt: {prompt}")

    generated_text = model.generate(
        prompt=prompt,
        max_length=1024,
        temperature=0.7,
        stop_tokens={0, 1, 2, 3}  # special tokens used as stop markers
    )

    print("\nFull text:")
    print(generated_text)

if __name__ == '__main__':
    main()
|
rwkv_tokenizer.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List,Set,Dict
|
| 2 |
+
|
| 3 |
+
class ABCTokenizer():
    """Trivial tokenizer: each character maps to its Unicode code point.

    Token ids 0-3 are reserved for special tokens (pad/bos/eos) and are
    never emitted as text by decode().
    """

    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 2
        self.eos_token_id = 3

    def encode(self, text):
        """Return the list of code points for *text*."""
        return [ord(ch) for ch in text]

    def decode(self, ids):
        """Turn code points back into text, dropping special tokens."""
        chars = []
        for code in ids:
            if code > self.eos_token_id:
                chars.append(chr(code))
        return ''.join(chars)
|
| 14 |
+
|
| 15 |
+
class RWKV_TOKENIZER():
    """Greedy longest-match byte-level tokenizer for the RWKV vocabulary.

    The vocab file has one token per line: ``<idx> <python literal> <len>``.
    Matching is accelerated by tables indexed on the first two bytes.
    """
    table: List[List[List[bytes]]]
    good: List[Set[int]]
    wlen: List[int]

    def __init__(self, file_name):
        """Load the vocab file and build the fast-matching tables."""
        self.idx2token = {}
        token_list = []  # must be already sorted in the file
        # NOTE: eval() parses each token literal — the vocab file must be
        # trusted; do not load vocab files from untrusted sources.
        # Fixed: file handle was previously never closed.
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            token_list.append(x)  # fixed: local no longer shadows builtin `sorted`
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # Precompute tables keyed on the first two bytes for fast matching:
        # table[b0][b1] = candidate tokens, good[b0] = possible second bytes,
        # wlen[b0] = longest token starting with b0.
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for i in reversed(range(len(token_list))):  # reverse order - match longer tokens first
            s = token_list[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> List[int]:
        """Greedily tokenize a byte string into vocab ids."""
        src_len: int = len(src)
        tokens: List[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # fall back to the single byte

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # Fixed: bare `except: pass` replaced with an explicit
                    # default — no candidate simply keeps the single byte.
                    match = next(filter(sss.startswith, self.table[s0][s1]), None)
                    if match is not None:
                        s = match
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of the given token ids."""
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        """Tokenize a UTF-8 string into vocab ids."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids back into a UTF-8 string."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Debug helper: print each token's repr and id."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:  # fixed: was a bare except
                pass
            print(f'{repr(s)}{i}', end=' ')
        # print(repr(s), i)
        print()
|
rwkv_vocab_v20230424.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ztu_somemodelruntime_rknnlite2.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 模块级常量和函数
|
| 2 |
+
from rknnlite.api import RKNNLite
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import warnings
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Dict, Union, Optional
|
| 8 |
+
|
| 9 |
+
# 配置日志
|
| 10 |
+
logger = logging.getLogger("somemodelruntime_rknnlite2")
|
| 11 |
+
logger.setLevel(logging.ERROR) # 默认只输出错误信息
|
| 12 |
+
if not logger.handlers:
|
| 13 |
+
handler = logging.StreamHandler()
|
| 14 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
| 15 |
+
logger.addHandler(handler)
|
| 16 |
+
|
| 17 |
+
# ONNX Runtime日志级别到Python logging级别的映射
|
| 18 |
+
_LOGGING_LEVEL_MAP = {
|
| 19 |
+
0: logging.DEBUG, # Verbose
|
| 20 |
+
1: logging.INFO, # Info
|
| 21 |
+
2: logging.WARNING, # Warning
|
| 22 |
+
3: logging.ERROR, # Error
|
| 23 |
+
4: logging.CRITICAL # Fatal
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def set_default_logger_severity(level: int) -> None:
    """Set the default logging severity.

    Levels: 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal.

    Args:
        level: logging level, an integer in [0, 4].

    Raises:
        ValueError: if *level* is not a recognized level.
    """
    mapped = _LOGGING_LEVEL_MAP.get(level)
    if mapped is None:
        raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
    logger.setLevel(mapped)
|
| 36 |
+
|
| 37 |
+
def set_default_logger_verbosity(level: int) -> None:
    """Set the default logging verbosity level.

    Delegates to :func:`set_default_logger_severity`; to see verbose
    output the severity must be set to 0 (Verbose).

    Args:
        level: logging level, an integer in [0, 4].
    """
    set_default_logger_severity(level)
|
| 46 |
+
|
| 47 |
+
# NPU核心模式常量
|
| 48 |
+
NPU_CORE_AUTO = 0 # 自动选择
|
| 49 |
+
NPU_CORE_0 = 1 # 使用核心0
|
| 50 |
+
NPU_CORE_1 = 2 # 使用核心1
|
| 51 |
+
NPU_CORE_2 = 4 # 使用核心2
|
| 52 |
+
NPU_CORE_0_1 = 3 # 使用核心0和1
|
| 53 |
+
NPU_CORE_0_1_2 = 7 # 使用所有核心
|
| 54 |
+
NPU_CORE_ALL = 0xffff # 使用所有核心
|
| 55 |
+
|
| 56 |
+
# RKNN tensor type到numpy dtype的映射
|
| 57 |
+
RKNN_DTYPE_MAP = {
|
| 58 |
+
0: np.float32, # RKNN_TENSOR_FLOAT32
|
| 59 |
+
1: np.float16, # RKNN_TENSOR_FLOAT16
|
| 60 |
+
2: np.int8, # RKNN_TENSOR_INT8
|
| 61 |
+
3: np.uint8, # RKNN_TENSOR_UINT8
|
| 62 |
+
4: np.int16, # RKNN_TENSOR_INT16
|
| 63 |
+
5: np.uint16, # RKNN_TENSOR_UINT16
|
| 64 |
+
6: np.int32, # RKNN_TENSOR_INT32
|
| 65 |
+
7: np.uint32, # RKNN_TENSOR_UINT32
|
| 66 |
+
8: np.int64, # RKNN_TENSOR_INT64
|
| 67 |
+
9: bool, # RKNN_TENSOR_BOOL
|
| 68 |
+
10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def get_available_providers() -> List[str]:
    """Return the available execution providers.

    Placeholder kept for ONNX Runtime API compatibility; the RKNNLite
    backend always reports a single CPU provider.

    Returns:
        list: always ``["CPUExecutionProvider"]``.
    """
    return ["CPUExecutionProvider"]
|
| 79 |
+
|
| 80 |
+
def get_version_info() -> Dict[str, str]:
    """Return API and driver version information.

    Parses the multi-line text reported by ``RKNNLite.get_sdk_version()``;
    line 3 carries the API version, line 4 the driver version (fixed layout
    of the SDK's version banner).

    Returns:
        dict: keys ``"api_version"`` and ``"driver_version"``.
    """
    sdk_text = RKNNLite().get_sdk_version()
    lines = sdk_text.split('\n')
    api_field = lines[2].split(': ')[1]
    driver_field = lines[3].split(': ')[1]
    return {
        "api_version": api_field.split(' ')[0],
        "driver_version": driver_field,
    }
|
| 93 |
+
|
| 94 |
+
class IOTensor:
    """Lightweight record describing a model input/output tensor."""

    def __init__(self, name, shape, type=None):
        # RKNN attrs may carry the tensor name as bytes; normalize to str.
        if isinstance(name, bytes):
            name = name.decode()
        self.name = name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
|
| 103 |
+
|
| 104 |
+
class SessionOptions:
    """Options controlling how an InferenceSession initializes its runtime."""

    def __init__(self):
        # Whether inference runs in asynchronous mode.
        self.async_mode = False
        # NPU core bitmask (see the NPU_CORE_* constants); 0 selects auto.
        self.core_mask = 0
        # Whether to enable runtime performance profiling.
        self.perf_debug = False
|
| 110 |
+
|
| 111 |
+
class InferenceSession:
    """
    RKNNLite runtime wrapper whose API mirrors ONNX Runtime's InferenceSession.
    """

    def __init__(self, model_path: str, verbose: bool = False, session_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load a model.

        Args:
            model_path: path to the model file (.rknn or .onnx)
            verbose: whether to emit verbose logs
            session_options: session options (async mode, NPU core mask, perf debug)
            **kwargs: additional initialization parameters (currently unused)

        Raises:
            FileNotFoundError: if the model file (or its .rknn sibling) is missing.
            RuntimeError: if loading the model or initializing the runtime fails.
        """
        # Only turn on verbose logging when explicitly requested.
        if verbose:
            set_default_logger_severity(0)  # Verbose

        self.model_path = self._process_model_path(model_path)
        self.runtime = RKNNLite(verbose=verbose)

        # Load the model.
        logger.debug(f"正在加载模型: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"加载RKNN模型失败: {self.model_path}")
            raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
        logger.debug("模型加载成功")

        # Apply session options (fall back to defaults when none given).
        options = session_options or SessionOptions()

        # Initialize the runtime environment.
        logger.debug("正在初始化运行时环境")
        ret = self.runtime.init_runtime(
            async_mode=options.async_mode,
            core_mask=options.core_mask
        )
        if ret != 0:
            logger.error("初始化运行时环境失败")
            raise RuntimeError('初始化运行时环境失败')
        logger.debug("运行时环境初始化成功")

        # Cache input/output tensor metadata.
        self._init_io_info()

        # Keep the options for later queries (get_session_options, perf).
        self.options = options

    def get_performance_info(self) -> Dict[str, float]:
        """
        Return performance information from the last run.

        Returns:
            dict: performance figures; ``run_duration`` is in milliseconds.

        Raises:
            RuntimeError: if profiling was not enabled via SessionOptions.
        """
        if not self.options.perf_debug:
            raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True")

        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
        }

    def set_core_mask(self, core_mask: int) -> None:
        """
        Select which NPU cores to use.

        Args:
            core_mask: NPU core bitmask; use the NPU_CORE_* constants.

        Raises:
            RuntimeError: if the runtime rejects the core mask.
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("设置NPU核心模式失败")

    def _process_model_path(self, model_path):
        """Resolve the model path, accepting both .onnx and .rknn files.

        An .onnx path is not converted; instead the same-named .rknn file is
        looked up (conversion must be done beforehand with RKNN Toolkit).
        """
        if not os.path.exists(model_path):
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        # ONNX file given: try to load the sibling .rknn instead.
        if model_path.lower().endswith('.onnx'):
            logger.warning(
                "检测到ONNX模型文件。注意:SomeModelRuntime不会自动转换ONNX到RKNN。"
                "请先使用RKNN Toolkit转换模型。"
                "现在尝试加载同名的.rknn文件。"
            )
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if not os.path.exists(rknn_path):
                logger.error(f"RKNN模型文件不存在: {rknn_path}")
                raise FileNotFoundError(
                    f"RKNN模型文件不存在: {rknn_path}\n"
                    "请先使用RKNN Toolkit将ONNX模型转换为RKNN格式。"
                )
            return rknn_path

        return model_path

    def _convert_nhwc_to_nchw(self, shape):
        """Convert a 4-D NHWC shape to NCHW; other ranks pass through unchanged."""
        if len(shape) == 4:
            n, h, w, c = shape
            return [n, c, h, w]
        return shape

    def _init_io_info(self):
        """Query and cache the model's input/output tensor metadata."""
        runtime = self.runtime.rknn_runtime

        # Number of inputs and outputs.
        n_input, n_output = runtime.get_in_out_num()

        # Input tensors: RKNN reports 4-D shapes as NHWC; convert to NCHW
        # so the reported shapes match the ONNX Runtime convention.
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            shape = self._convert_nhwc_to_nchw(shape)
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            self.input_tensors.append(IOTensor(attr.name, shape, dtype))

        # Output tensors.
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            self.output_tensors.append(IOTensor(attr.name, shape, dtype))

    def get_inputs(self):
        """
        Return the model's input tensor descriptions.

        Returns:
            list: IOTensor entries for each model input.
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Return the model's output tensor descriptions.

        Returns:
            list: IOTensor entries for each model output.
        """
        return self.output_tensors

    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Execute model inference.

        Args:
            output_names: names of the outputs to return; None returns all
            input_feed: input data as a dict (name -> array) or a positional list
            data_format: layout of the input data, "nchw" or "nhwc"
            **kwargs: additional runtime parameters (currently unused)

        Returns:
            list: model outputs; only the requested ones when output_names is given

        Raises:
            ValueError: on missing/mismatched inputs or an unknown output name.
            RuntimeError: if the underlying inference call fails.
        """
        if input_feed is None:
            logger.error("input_feed不能为None")
            raise ValueError("input_feed不能为None")

        # Normalize input_feed into a list ordered like the model inputs.
        if isinstance(input_feed, dict):
            inputs = []
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"缺少输入: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed必须是字典或列表类型")
            raise ValueError("input_feed必须是字典或列表类型")

        # Only the inference call itself is wrapped: output-name selection
        # errors below should surface as ValueError, not RuntimeError.
        # `from e` preserves the original traceback for debugging.
        try:
            logger.debug("开始执行推理")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
        except Exception as e:
            logger.error(f"推理执行失败: {str(e)}")
            raise RuntimeError(f"推理执行失败: {str(e)}") from e

        # No selection requested: return everything.
        if output_names is None:
            return all_outputs

        # Pick the requested outputs in the caller's order.
        output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
        selected_outputs = []
        for name in output_names:
            if name not in output_map:
                raise ValueError(f"未找到输出节点: {name}")
            selected_outputs.append(all_outputs[output_map[name]])
        return selected_outputs

    def close(self):
        """
        Close the session and release runtime resources (idempotent).
        """
        if self.runtime is not None:
            logger.info("正在释放运行时资源")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def end_profiling(self) -> Optional[str]:
        """
        Stub for ending profiling (ONNX Runtime API compatibility).

        Returns:
            Optional[str]: always None.
        """
        warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for querying the profiling start time.

        Returns:
            int: always 0.
        """
        warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for retrieving model metadata.

        Returns:
            Dict[str, str]: always an empty dict.
        """
        warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Return the session options this session was created with.

        Returns:
            SessionOptions: the current session options.
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for querying the active providers.

        Returns:
            List[str]: always ["CPUExecutionProvider"].
        """
        warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]

    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for querying provider options.

        Returns:
            Dict[str, Dict[str, str]]: always an empty dict.
        """
        warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for querying the session configuration.

        Returns:
            Dict[str, str]: always an empty dict.
        """
        warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for querying the session state.

        Returns:
            Dict[str, str]: always an empty dict.
        """
        warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for setting the session configuration.

        Args:
            config: session configuration dict (ignored).
        """
        warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for querying memory usage.

        Returns:
            Dict[str, int]: always an empty dict.
        """
        warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for enabling the memory pattern.

        Args:
            enable: whether to enable the memory pattern (ignored).
        """
        warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for disabling the memory pattern.
        """
        warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for querying the optimization level.

        Returns:
            int: always 0.
        """
        warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for setting the optimization level.

        Args:
            level: optimization level (ignored).
        """
        warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for retrieving model metadata (distinct from get_modelmeta).

        Returns:
            Dict[str, str]: always an empty dict.
        """
        warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Return the resolved model file path.

        Returns:
            str: the model file path.
        """
        return self.model_path

    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying input type information.

        Returns:
            List[Dict[str, str]]: always an empty list.
        """
        warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying output type information.

        Returns:
            List[Dict[str, str]]: always an empty list.
        """
        warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return []
|