Commit 8572c72
Parent(s): 8d44598

fix data-format

Files changed:
- lcm_server.py  +2 -2
- rknnlcm.py     +107 -16
lcm_server.py CHANGED

@@ -106,7 +106,7 @@ class PipelineWorker:
 
         print("seed ", job.req.seed)
         print("rng", rng)
-
+
         result = self.pipe(
            prompt=job.req.prompt,
            height=h,
@@ -114,7 +114,7 @@ class PipelineWorker:
            num_inference_steps=job.req.num_inference_steps,
            guidance_scale=job.req.guidance_scale,
            generator=rng,
-        )
+        )
 
         pil_image = result["images"][0]
         buf = io.BytesIO()

(The changed lines at 109 and 117 render identically in this view, suggesting a whitespace-only change.)
rknnlcm.py CHANGED

@@ -43,9 +43,10 @@ class RKNN2Model:
         *,
         core_mask: Optional[Union[str, int]] = None,
         multi_context: bool = True,
-        data_format: str = "
+        data_format: str = "nhwc",
         verbose_shapes: bool = False,
         runtime_kwargs: Optional[dict] = None,
+        force_fp32=True,
         **_ignored: Any,
     ):
         """
@@ -61,6 +62,8 @@ class RKNN2Model:
         - runtime_kwargs: optional extra kwargs to pass into init_runtime(...)
         - **_ignored: allows you to pass context_name/worker_id etc without breaking
         """
+        self.data_format = data_format.lower()
+        self.force_fp32 = force_fp32
         self.model_dir = model_dir
         self.data_format = data_format
         self.verbose_shapes = verbose_shapes
@@ -123,23 +126,32 @@ class RKNN2Model:
             raise TypeError(f"core_mask must be None, int, or str; got {type(core_mask)}")
 
     def __call__(self, **kwargs) -> List[np.ndarray]:
-        #
-        input_list =
-
-        if self.verbose_shapes:
-            for i, arr in enumerate(input_list):
-                if isinstance(arr, np.ndarray):
-                    logger.info(f"[{self.modelname}] input[{i}] shape={arr.shape} dtype={arr.dtype}")
-
+        # TODO We need deterministic ordering
+        input_list = [self._prep(v) for v in kwargs.values()]
         results = self.rknnlite.inference(inputs=input_list, data_format=self.data_format)
 
-        if self.verbose_shapes:
-            for j, res in enumerate(results):
-                if isinstance(res, np.ndarray):
-                    logger.info(f"[{self.modelname}] output[{j}] shape={res.shape} dtype={res.dtype}")
+        logger.info("%s out[0] shape=%s dtype=%s",
+                    self.modelname, results[0].shape, results[0].dtype)
 
         return results
 
+    def _prep(self, x):
+        import numpy as np
+        if isinstance(x, np.ndarray):
+            # dtype safety
+            if self.force_fp32 and x.dtype in (np.float64, np.float16):
+                x = x.astype(np.float32, copy=False)
+
+            # layout safety for 4D tensors
+            if x.ndim == 4:
+                if self.data_format == "nhwc" and x.shape[1] in (1, 3, 4):    # likely NCHW
+                    x = x.transpose(0, 2, 3, 1)
+                elif self.data_format == "nchw" and x.shape[-1] in (1, 3, 4):  # likely NHWC
+                    x = x.transpose(0, 3, 1, 2)
+
+            x = np.ascontiguousarray(x)
+        return x
+
 class RKNN2LatentConsistencyPipeline(DiffusionPipeline):
 
     def __init__(
@@ -379,6 +391,11 @@ class RKNN2LatentConsistencyPipeline(DiffusionPipeline):
                     f" {negative_prompt_embeds.shape}."
                 )
 
+    # Keep latents in NCHW everywhere in Python, and only convert to NHWC right at
+    # the RKNN boundary for models that require it. That means:
+    #   • Before the UNet RKNN call: NCHW -> NHWC
+    #   • After the UNet RKNN call: NHWC -> NCHW (only if the raw output is NHWC)
+    #   • VAE decoder input: if it expects NHWC, convert right before it too.
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -540,6 +557,7 @@ class RKNN2LatentConsistencyPipeline(DiffusionPipeline):
             timestep_dtype = np.int64
 
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
         inference_start = time.time()
         for i, t in enumerate(self.progress_bar(timesteps)):
             timestep = np.array([t], dtype=timestep_dtype)
@@ -654,9 +672,9 @@ def generate_png_bytes(args):
     user_specified_scheduler = LCMScheduler.from_config(scheduler_config)
 
     pipe = RKNN2LatentConsistencyPipeline(
-        text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder")),
-        unet=RKNN2Model(os.path.join(args.i, "unet")),
-        vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder")),
+        text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder"), data_format="nchw"),  # probably irrelevant
+        unet=RKNN2Model(os.path.join(args.i, "unet"), data_format="nhwc"),  # important
+        vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder"), data_format="nhwc"),  # important
         scheduler=user_specified_scheduler,
         tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
     )
@@ -680,3 +698,76 @@ def generate_png_bytes(args):
     buf.seek(0)
 
     return buf.getvalue()
+
+def main(args):
+    logger.info(f"Setting random seed to {args.seed}")
+
+    # load scheduler from scheduler/scheduler_config.json
+    scheduler_config_path = os.path.join(args.i, "scheduler/scheduler_config.json")
+    with open(scheduler_config_path, "r") as f:
+        scheduler_config = json.load(f)
+    user_specified_scheduler = LCMScheduler.from_config(scheduler_config)
+
+    logger.info("Using scheduler: %s", user_specified_scheduler.__class__.__name__)
+
+    # Parse size as WIDTHxHEIGHT (common CLI convention)
+    w_str, h_str = args.size.lower().split("x")
+    width, height = int(w_str), int(h_str)
+
+    pipe = RKNN2LatentConsistencyPipeline(
+        text_encoder=RKNN2Model(os.path.join(args.i, "text_encoder"), data_format="nchw"),
+        unet=RKNN2Model(os.path.join(args.i, "unet"), data_format="nhwc"),
+        vae_decoder=RKNN2Model(os.path.join(args.i, "vae_decoder"), data_format="nhwc"),
+        scheduler=user_specified_scheduler,
+        tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16"),
+    )
+
+    logger.info("Beginning image generation.")
+    out = pipe(
+        prompt=args.prompt,
+        height=height,
+        width=width,
+        num_inference_steps=args.num_inference_steps,
+        guidance_scale=args.guidance_scale,
+        generator=np.random.RandomState(args.seed),
+    )
+
+    out_path = get_image_path(args)
+    logger.info("Saving generated image to %s", out_path)
+    out["images"][0].save(out_path)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--prompt",
+        required=True,
+        help="The text prompt to be used for text-to-image generation.")
+    parser.add_argument(
+        "-i",
+        required=True,
+        help="Path to model directory")
+    parser.add_argument("-o", required=True)
+    parser.add_argument("--seed",
+                        default=93,
+                        type=int,
+                        help="Random seed to be able to reproduce results")
+    parser.add_argument(
+        "-s",
+        "--size",
+        default="256x256",
+        type=str,
+        help="Image size")
+    parser.add_argument(
+        "--num-inference-steps",
+        default=4,
+        type=int,
+        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")
+    parser.add_argument(
+        "--guidance-scale",
+        default=7.5,
+        type=float,
+        help="Controls the influence of the text prompt on sampling process (0=random images)")
+
+    args = parser.parse_args()
+    main(args)
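Note on the new _prep input guard: it decides whether to transpose purely from the channel count of a 4D array, so the heuristic is worth seeing in isolation. Below is a minimal standalone mirror of the committed logic; the free function prep is illustrative only, not part of the commit:

    import numpy as np

    def prep(x, data_format="nhwc", force_fp32=True):
        # Mirrors RKNN2Model._prep from this commit: upcast half/double floats
        # to fp32, then transpose 4D tensors whose channel axis looks misplaced.
        if isinstance(x, np.ndarray):
            if force_fp32 and x.dtype in (np.float64, np.float16):
                x = x.astype(np.float32, copy=False)
            if x.ndim == 4:
                if data_format == "nhwc" and x.shape[1] in (1, 3, 4):    # looks like NCHW
                    x = x.transpose(0, 2, 3, 1)
                elif data_format == "nchw" and x.shape[-1] in (1, 3, 4):  # looks like NHWC
                    x = x.transpose(0, 3, 1, 2)
            x = np.ascontiguousarray(x)
        return x

    # A 4-channel SD latent supplied in NCHW is rewritten to NHWC for an "nhwc" model:
    latent = np.random.randn(1, 4, 32, 32).astype(np.float16)
    assert prep(latent).shape == (1, 32, 32, 4)
    assert prep(latent).dtype == np.float32
    # 3D tensors (e.g. text-encoder hidden states of shape (1, 77, 768)) pass through:
    assert prep(np.zeros((1, 77, 768), dtype=np.float32), "nchw").shape == (1, 77, 768)

One limit of the channel-count guess: an already-NHWC tensor whose height happens to be 1, 3, or 4 would be transposed a second time, so the guard is only safe while spatial sizes stay well above the channel count.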
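The comment block added before prepare_latents prescribes converting UNet output back to NCHW only when the raw result is actually NHWC, but the commit does not yet include that output-side converter. A sketch of what such a helper could look like, assuming 4-channel latents (to_nchw is a hypothetical name, not in the file):

    import numpy as np

    def to_nchw(x: np.ndarray, num_channels: int = 4) -> np.ndarray:
        # Transpose only when the last axis looks like the channel axis and the
        # second axis does not, i.e. the array is plausibly NHWC rather than NCHW.
        if x.ndim == 4 and x.shape[-1] == num_channels and x.shape[1] != num_channels:
            x = x.transpose(0, 3, 1, 2)
        return np.ascontiguousarray(x)

    # (1, 32, 32, 4) -> (1, 4, 32, 32); an already-NCHW (1, 4, 32, 32) is left alone.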
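The new __main__ block makes rknnlcm.py runnable on its own. A plausible invocation against a converted model directory (the path below is a placeholder, and the final output location comes from get_image_path, which is defined elsewhere in the file):

    python rknnlcm.py --prompt "an astronaut riding a horse" \
        -i ./models/lcm-rknn -o out.png \
        --size 256x256 --num-inference-steps 4 --guidance-scale 7.5 --seed 93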