fix: NHWC data_format for UNet/VAE decoder under librknnrt 2.3.2
#6
by jaylfc - opened
- run_rknn-lcm.py +33 -9
run_rknn-lcm.py
CHANGED
|
@@ -30,9 +30,21 @@ from rknnlite.api import RKNNLite
|
|
| 30 |
class RKNN2Model:
|
| 31 |
""" Wrapper for running RKNPU2 models """
|
| 32 |
|
| 33 |
-
def __init__(self, model_dir):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
logger.info(f"Loading {model_dir}")
|
| 35 |
start = time.time()
|
|
|
|
| 36 |
self.config = json.load(open(os.path.join(model_dir, "config.json")))
|
| 37 |
assert os.path.exists(model_dir) and os.path.exists(os.path.join(model_dir, "model.rknn"))
|
| 38 |
self.rknnlite = RKNNLite()
|
|
@@ -44,15 +56,27 @@ class RKNN2Model:
|
|
| 44 |
self.inference_time = 0
|
| 45 |
|
| 46 |
def __call__(self, **kwargs) -> List[np.ndarray]:
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
for i, input in enumerate(input_list):
|
| 52 |
if isinstance(input, np.ndarray):
|
| 53 |
print(f"input {i} shape: {input.shape}")
|
| 54 |
|
| 55 |
-
results = self.rknnlite.inference(inputs=input_list, data_format=
|
| 56 |
for res in results:
|
| 57 |
print(f"output shape: {res.shape}")
|
| 58 |
return results
|
|
@@ -573,9 +597,9 @@ def main(args):
|
|
| 573 |
print("user_specified_scheduler", user_specified_scheduler)
|
| 574 |
|
| 575 |
pipe = RKNN2LatentConsistencyPipeline(
|
| 576 |
-
text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder")),
|
| 577 |
-
unet=RKNN2Model(os.path.join(args.i, "unet")),
|
| 578 |
-
vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder")),
|
| 579 |
scheduler=user_specified_scheduler,
|
| 580 |
tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
|
| 581 |
)
|
|
|
|
| 30 |
class RKNN2Model:
|
| 31 |
""" Wrapper for running RKNPU2 models """
|
| 32 |
|
| 33 |
+
def __init__(self, model_dir, data_format: str = "nchw"):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
model_dir: directory containing config.json + model.rknn
|
| 37 |
+
data_format: "nchw" or "nhwc" — how the RKNN model expects
|
| 38 |
+
its 4-D inputs. Under librknnrt 2.3.2 (2025-04-09) the
|
| 39 |
+
runtime's automatic NCHW→NHWC conversion path is no
|
| 40 |
+
longer reliable for this UNet, so models that expect
|
| 41 |
+
NHWC (the UNet and the VAE decoder in this pipeline)
|
| 42 |
+
must be told explicitly and their input tensors
|
| 43 |
+
transposed in Python before the inference call.
|
| 44 |
+
"""
|
| 45 |
logger.info(f"Loading {model_dir}")
|
| 46 |
start = time.time()
|
| 47 |
+
self.data_format = data_format.lower()
|
| 48 |
self.config = json.load(open(os.path.join(model_dir, "config.json")))
|
| 49 |
assert os.path.exists(model_dir) and os.path.exists(os.path.join(model_dir, "model.rknn"))
|
| 50 |
self.rknnlite = RKNNLite()
|
|
|
|
| 56 |
self.inference_time = 0
|
| 57 |
|
| 58 |
def __call__(self, **kwargs) -> List[np.ndarray]:
|
| 59 |
+
def prep(x):
|
| 60 |
+
if isinstance(x, np.ndarray):
|
| 61 |
+
# dtype safety — the runtime wants float32
|
| 62 |
+
if x.dtype in (np.float16, np.float64):
|
| 63 |
+
x = x.astype(np.float32, copy=False)
|
| 64 |
+
# layout safety: transpose 4-D tensors to match the
|
| 65 |
+
# declared data_format at the RKNN boundary
|
| 66 |
+
if x.ndim == 4:
|
| 67 |
+
if self.data_format == "nhwc" and x.shape[1] in (1, 3, 4):
|
| 68 |
+
x = x.transpose(0, 2, 3, 1) # NCHW -> NHWC
|
| 69 |
+
elif self.data_format == "nchw" and x.shape[-1] in (1, 3, 4):
|
| 70 |
+
x = x.transpose(0, 3, 1, 2) # NHWC -> NCHW
|
| 71 |
+
x = np.ascontiguousarray(x)
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
input_list = [prep(v) for v in kwargs.values()]
|
| 75 |
for i, input in enumerate(input_list):
|
| 76 |
if isinstance(input, np.ndarray):
|
| 77 |
print(f"input {i} shape: {input.shape}")
|
| 78 |
|
| 79 |
+
results = self.rknnlite.inference(inputs=input_list, data_format=self.data_format)
|
| 80 |
for res in results:
|
| 81 |
print(f"output shape: {res.shape}")
|
| 82 |
return results
|
|
|
|
| 597 |
print("user_specified_scheduler", user_specified_scheduler)
|
| 598 |
|
| 599 |
pipe = RKNN2LatentConsistencyPipeline(
|
| 600 |
+
text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder"), data_format="nchw"),
|
| 601 |
+
unet=RKNN2Model(os.path.join(args.i, "unet"), data_format="nhwc"),
|
| 602 |
+
vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder"), data_format="nhwc"),
|
| 603 |
scheduler=user_specified_scheduler,
|
| 604 |
tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
|
| 605 |
)
|