| |
|
| |
|
| |
|
| | from typing import List, Union |
| | import numpy as np |
| | import axengine |
| | import torch |
| | from PIL import Image |
| | from transformers import CLIPTokenizer, PreTrainedTokenizer |
| | import time |
| | import argparse |
| | import uuid |
| | import os |
| | import traceback |
| | from diffusers import DPMSolverMultistepScheduler |
| |
|
| | |
# Global switches for verbose diagnostics and optional per-line timestamps.
DEBUG_MODE = True
LOG_TIMESTAMP = True


def debug_log(msg):
    """Print *msg* as a [DEBUG] line when DEBUG_MODE is on.

    When LOG_TIMESTAMP is also on, the line is prefixed with the current
    local time in ``YYYY-MM-DD HH:MM:SS`` format.
    """
    if not DEBUG_MODE:
        return
    prefix = ""
    if LOG_TIMESTAMP:
        prefix = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
    print(f"{prefix}[DEBUG] {msg}")
| |
|
def get_args():
    """Parse command-line options for the txt2img pipeline.

    All options have defaults, so the script runs with no arguments.
    Returns the parsed ``argparse.Namespace``; prints the error and exits
    with status 1 if parser construction fails.
    """
    try:
        parser = argparse.ArgumentParser(
            prog="StableDiffusion",
            description="Generate picture with the input prompt using DPM++ sampler",
        )
        add = parser.add_argument
        add("--prompt", type=str, required=False,
            default="masterpiece, best quality, 1girl, (colorful),(delicate eyes and face), volumatic light, ray tracing, bust shot ,extremely detailed CG unity 8k wallpaper,solo,smile,intricate skirt,((flying petal)),(Flowery meadow) sky, cloudy_sky, moonlight, moon, night, (dark theme:1.3), light, fantasy, windy, magic sparks, dark castle,white hair",
            help="the input text prompt")
        add("--text_model_dir", type=str, required=False, default="./models/",
            help="Path to text encoder and tokenizer files")
        add("--unet_model", type=str, required=False, default="./models/unet.axmodel",
            help="Path to unet axmodel model")
        add("--vae_decoder_model", type=str, required=False, default="./models/vae_decoder.axmodel",
            help="Path to vae decoder axmodel model")
        add("--time_input", type=str, required=False,
            default="./models/time_input_dpmpp_20steps.npy",
            help="Path to time input file")
        add("--save_dir", type=str, required=False, default="./txt2img_output_axe",
            help="Path to the output image file")
        add("--num_inference_steps", type=int, default=20,
            help="Number of inference steps for DPM++ sampler")
        add("--guidance_scale", type=float, default=7.5, help="Guidance scale for CFG")
        add("--seed", type=int, default=None, help="Random seed")
        return parser.parse_args()
    except Exception as e:
        print(f"参数解析失败: {str(e)}")
        traceback.print_exc()
        exit(1)
| |
|
def get_embeds(prompt, negative_prompt, tokenizer_dir, text_encoder_dir):
    """Encode the positive and negative prompts with the CLIP text encoder,
    validating the output shapes.

    Args:
        prompt: positive prompt text.
        negative_prompt: negative prompt text.
        tokenizer_dir: directory with the CLIP tokenizer files.
        text_encoder_dir: directory containing ``sd15_text_encoder_sim.axmodel``.

    Returns:
        ``(neg_embeds, pos_embeds)`` — two numpy arrays of shape (1, 77, 768).

    Prints the error and exits the process with status 1 on any failure.
    """
    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()

        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

        # Load the text-encoder model ONCE and reuse the session for both
        # prompts. (Previously a fresh InferenceSession was created inside
        # process_prompt, loading the model from disk twice per call.)
        model_path = os.path.join(text_encoder_dir, "sd15_text_encoder_sim.axmodel")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"文本编码器模型不存在: {model_path}")
        session = axengine.InferenceSession(model_path)

        def process_prompt(prompt_text):
            # CLIP uses a fixed context length of 77 tokens; pad/truncate to it.
            inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            outputs = session.run(None, {"input_ids": inputs.input_ids.numpy().astype(np.int32)})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # SD1.5 text embeddings must be (batch=1, seq=77, dim=768).
        if neg_embeds.shape != (1, 77, 768) or pos_embeds.shape != (1, 77, 768):
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        exit(1)
| |
|
def main():
    """Run the txt2img pipeline end to end.

    Steps: parse args, seed RNGs, validate model paths, build the DPM++
    scheduler, load the UNet/VAE models on the NPU, encode the prompts,
    run the denoising loop with classifier-free guidance, decode the
    final latent with the VAE, and save the result as a PNG under
    ``args.save_dir``. Prints the error and exits with status 1 on any
    failure.
    """
    try:
        debug_log("程序启动")
        args = get_args()
        debug_log(f"参数解析完成 | 随机种子: {args.seed} | 推理步数: {args.num_inference_steps}")

        # Seed torch and numpy. Compare against None so an explicit
        # "--seed 0" is honored (a plain truthiness test would silently
        # replace 0 with the current time).
        seed = args.seed if args.seed is not None else int(time.time())
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"随机种子设置完成: {seed}")

        # Fail fast if any required model artifact is missing.
        model_paths = [
            args.unet_model,
            args.vae_decoder_model,
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        ]
        for path in model_paths:
            if not os.path.exists(path):
                raise FileNotFoundError(f"模型路径不存在: {path}")

        # DPM-Solver++ scheduler with the standard SD1.5 beta schedule;
        # Karras sigmas improve quality at low step counts.
        debug_log("初始化调度器...")
        scheduler_start = time.time()
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(args.num_inference_steps)
        debug_log(f"调度器初始化完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

        # Load the UNet and VAE-decoder models onto the NPU.
        debug_log("加载NPU模型...")
        model_load_start = time.time()
        unet_session_main = axengine.InferenceSession(args.unet_model)
        vae_decoder = axengine.InferenceSession(args.vae_decoder_model)
        debug_log(f"模型加载完成 | 总耗时: {(time.time()-model_load_start):.2f}s")
        debug_log(f"UNET输入信息: {[str(inp) for inp in unet_session_main.get_inputs()]}")
        debug_log(f"VAE输入信息: {[str(inp) for inp in vae_decoder.get_inputs()]}")

        # Encode the user prompt and a fixed negative prompt for CFG.
        embed_start = time.time()
        neg_embeds, pos_embeds = get_embeds(
            args.prompt,
            "sketch, duplicate, ugly...",
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        )
        debug_log(f"提示词处理完成 | 总耗时: {(time.time()-embed_start):.2f}s")

        # Initial latent noise scaled by the scheduler's init sigma.
        # Latent shape (1, 4, 60, 40) presumably maps to a 480x320 image
        # (SD latents are 1/8 the pixel resolution) — TODO confirm against
        # the exported UNet's expected input.
        latents_shape = [1, 4, 60, 40]
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(latents_shape, generator=generator)
        init_scale = scheduler.init_noise_sigma
        latent = latent * init_scale
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} | 缩放系数: {init_scale}")
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量转换完成 | dtype: {latent.dtype}")

        # Precomputed per-step time embeddings fed to an internal UNet
        # tensor input (the model was presumably exported with the
        # timestep-embedding layers cut out — note the ONNX-style input
        # name used below).
        debug_log(f"加载时间嵌入: {args.time_input}")
        time_data = np.load(args.time_input)
        if len(time_data) < args.num_inference_steps:
            raise ValueError(f"时间嵌入不足: 需要{args.num_inference_steps}, 实际{len(time_data)}")
        time_data = time_data[:args.num_inference_steps]
        debug_log(f"时间嵌入验证通过 | 形状: {time_data.shape}")

        debug_log("开始采样循环...")
        total_unet_time = 0
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()
            debug_log(f"\n--- 步骤 {step_idx+1}/{args.num_inference_steps} [ts={timestep}] ---")

            try:
                # Guard against numerical blow-ups early.
                if np.isnan(latent).any():
                    raise ValueError("潜在变量包含NaN值!")

                # Add the batch axis expected by the UNet.
                time_emb = np.expand_dims(time_data[step_idx], axis=0)
                debug_log(f"时间嵌入形状: {time_emb.shape}")

                # Two UNet passes per step — unconditional (negative) and
                # conditional (positive) — combined below via CFG.
                debug_log("运行UNET(负面提示)...")
                unet_neg_start = time.time()
                noise_pred_neg = unet_session_main.run(None, {
                    "sample": latent,
                    "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                    "encoder_hidden_states": neg_embeds
                })[0]
                debug_log(f"UNET(负面)完成 | 形状: {noise_pred_neg.shape} | 耗时: {(time.time()-unet_neg_start):.2f}s")

                debug_log("运行UNET(正面提示)...")
                unet_pos_start = time.time()
                noise_pred_pos = unet_session_main.run(None, {
                    "sample": latent,
                    "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
                    "encoder_hidden_states": pos_embeds
                })[0]
                debug_log(f"UNET(正面)完成 | 耗时: {(time.time()-unet_pos_start):.2f}s")

                # Classifier-free guidance: push the prediction away from
                # the unconditional output by guidance_scale.
                debug_log(f"应用CFG指导(scale={args.guidance_scale})...")
                noise_pred = noise_pred_neg + args.guidance_scale * (noise_pred_pos - noise_pred_neg)
                debug_log(f"噪声预测范围: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")

                latent_tensor = torch.from_numpy(latent)
                noise_pred_tensor = torch.from_numpy(noise_pred)

                # Scheduler update x_t -> x_{t-1}. Loop-local timer name so
                # the outer scheduler_start variable is not shadowed.
                debug_log("更新潜在变量...")
                step_sched_start = time.time()
                latent_tensor = scheduler.step(
                    model_output=noise_pred_tensor,
                    timestep=timestep,
                    sample=latent_tensor
                ).prev_sample
                debug_log(f"调度器更新完成 | 耗时: {(time.time()-step_sched_start):.2f}s")

                latent = latent_tensor.numpy().astype(np.float32)
                debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

                step_time = time.time() - step_start
                total_unet_time += step_time
                debug_log(f"步骤完成 | 单步耗时: {step_time:.2f}s | 累计耗时: {total_unet_time:.2f}s")

            except Exception as e:
                print(f"步骤 {step_idx+1} 执行失败: {str(e)}")
                traceback.print_exc()
                exit(1)

        # VAE decode: undo the SD latent scaling factor (0.18215) first.
        debug_log("\n开始VAE解码...")
        vae_start = time.time()
        try:
            latent = latent / 0.18215
            debug_log(f"VAE输入范围: [{latent.min():.3f}, {latent.max():.3f}]")
            image = vae_decoder.run(None, {"latent": latent})[0]
            debug_log(f"VAE输出形状: {image.shape} | 耗时: {(time.time()-vae_start):.2f}s")
        except Exception as e:
            print(f"VAE解码失败: {str(e)}")
            traceback.print_exc()
            exit(1)

        # Convert NCHW [-1, 1] output to HWC uint8 and save as PNG.
        debug_log("保存结果...")
        try:
            image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
            image_denorm = np.clip(image / 2 + 0.5, 0, 1)
            image = (image_denorm * 255).round().astype("uint8")
            debug_log(f"图像形状: {image.shape} | dtype: {image.dtype}")

            # Create the output directory if needed — without this the
            # save crashed on a fresh checkout where save_dir didn't exist.
            os.makedirs(args.save_dir, exist_ok=True)
            pil_image = Image.fromarray(image[:, :, :3])
            save_path = os.path.join(args.save_dir, f"{uuid.uuid4()}.png")
            pil_image.save(save_path)
            debug_log(f"图像保存成功: {save_path}")
        except Exception as e:
            print(f"保存失败: {str(e)}")
            traceback.print_exc()
            exit(1)

    except Exception as e:
        print(f"主流程执行失败: {str(e)}")
        traceback.print_exc()
        exit(1)
| |
|
if __name__ == '__main__':
    # Script entry point: run the full txt2img pipeline.
    main()
| |
|