fix: NHWC data_format for UNet/VAE decoder under librknnrt 2.3.2

#6
Files changed (1) hide show
  1. run_rknn-lcm.py +33 -9
run_rknn-lcm.py CHANGED
@@ -30,9 +30,21 @@ from rknnlite.api import RKNNLite
30
  class RKNN2Model:
31
  """ Wrapper for running RKNPU2 models """
32
 
33
- def __init__(self, model_dir):
 
 
 
 
 
 
 
 
 
 
 
34
  logger.info(f"Loading {model_dir}")
35
  start = time.time()
 
36
  self.config = json.load(open(os.path.join(model_dir, "config.json")))
37
  assert os.path.exists(model_dir) and os.path.exists(os.path.join(model_dir, "model.rknn"))
38
  self.rknnlite = RKNNLite()
@@ -44,15 +56,27 @@ class RKNN2Model:
44
  self.inference_time = 0
45
 
46
  def __call__(self, **kwargs) -> List[np.ndarray]:
47
- # np.savez(f"rknn_out/{self.modelname}_input_{self.inference_time}.npz", **kwargs)
48
- # self.inference_time += 1
49
- #print(kwargs)
50
- input_list = [value for key, value in kwargs.items()]
 
 
 
 
 
 
 
 
 
 
 
 
51
  for i, input in enumerate(input_list):
52
  if isinstance(input, np.ndarray):
53
  print(f"input {i} shape: {input.shape}")
54
 
55
- results = self.rknnlite.inference(inputs=input_list, data_format='nchw')
56
  for res in results:
57
  print(f"output shape: {res.shape}")
58
  return results
@@ -573,9 +597,9 @@ def main(args):
573
  print("user_specified_scheduler", user_specified_scheduler)
574
 
575
  pipe = RKNN2LatentConsistencyPipeline(
576
- text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder")),
577
- unet=RKNN2Model(os.path.join(args.i, "unet")),
578
- vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder")),
579
  scheduler=user_specified_scheduler,
580
  tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
581
  )
 
30
  class RKNN2Model:
31
  """ Wrapper for running RKNPU2 models """
32
 
33
+ def __init__(self, model_dir, data_format: str = "nchw"):
34
+ """
35
+ Args:
36
+ model_dir: directory containing config.json + model.rknn
37
+ data_format: "nchw" or "nhwc" — how the RKNN model expects
38
+ its 4-D inputs. Under librknnrt 2.3.2 (2025-04-09) the
39
+ runtime's automatic NCHW→NHWC conversion path is no
40
+ longer reliable for this UNet, so models that expect
41
+ NHWC (the UNet and the VAE decoder in this pipeline)
42
+ must be told explicitly and their input tensors
43
+ transposed in Python before the inference call.
44
+ """
45
  logger.info(f"Loading {model_dir}")
46
  start = time.time()
47
+ self.data_format = data_format.lower()
48
  self.config = json.load(open(os.path.join(model_dir, "config.json")))
49
  assert os.path.exists(model_dir) and os.path.exists(os.path.join(model_dir, "model.rknn"))
50
  self.rknnlite = RKNNLite()
 
56
  self.inference_time = 0
57
 
58
  def __call__(self, **kwargs) -> List[np.ndarray]:
59
+ def prep(x):
60
+ if isinstance(x, np.ndarray):
61
+ # dtype safety — the runtime wants float32
62
+ if x.dtype in (np.float16, np.float64):
63
+ x = x.astype(np.float32, copy=False)
64
+ # layout safety: transpose 4-D tensors to match the
65
+ # declared data_format at the RKNN boundary
66
+ if x.ndim == 4:
67
+ if self.data_format == "nhwc" and x.shape[1] in (1, 3, 4):
68
+ x = x.transpose(0, 2, 3, 1) # NCHW -> NHWC
69
+ elif self.data_format == "nchw" and x.shape[-1] in (1, 3, 4):
70
+ x = x.transpose(0, 3, 1, 2) # NHWC -> NCHW
71
+ x = np.ascontiguousarray(x)
72
+ return x
73
+
74
+ input_list = [prep(v) for v in kwargs.values()]
75
  for i, input in enumerate(input_list):
76
  if isinstance(input, np.ndarray):
77
  print(f"input {i} shape: {input.shape}")
78
 
79
+ results = self.rknnlite.inference(inputs=input_list, data_format=self.data_format)
80
  for res in results:
81
  print(f"output shape: {res.shape}")
82
  return results
 
597
  print("user_specified_scheduler", user_specified_scheduler)
598
 
599
  pipe = RKNN2LatentConsistencyPipeline(
600
+ text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder"), data_format="nchw"),
601
+ unet=RKNN2Model(os.path.join(args.i, "unet"), data_format="nhwc"),
602
+ vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder"), data_format="nhwc"),
603
  scheduler=user_specified_scheduler,
604
  tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
605
  )