|
|
|
|
|
"""收集所有子ONNX模型的输入数据,用于量化校准。 |
|
|
|
|
|
此脚本通过推理完整的原始ONNX模型来收集各个子图的输入数据。 |
|
|
会为每个子图创建独立的文件夹,保存输入数据,并打包成tar文件。 |
|
|
""" |
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import tarfile |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
import torch |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
|
|
|
# Make the script importable no matter where it is launched from: this file's
# directory and its two nearest ancestors are pushed to the front of sys.path.
# The outermost ancestor is treated as the repository root.
current_file_path = os.path.abspath(__file__)
_script_dir = os.path.dirname(current_file_path)
_parent_dir = os.path.dirname(_script_dir)
_grandparent_dir = os.path.dirname(_parent_dir)
project_roots = [_script_dir, _parent_dir, _grandparent_dir]
for project_root in project_roots:
    if project_root in sys.path:
        continue
    sys.path.insert(0, project_root)
repo_root = project_roots[-1]
|
|
|
|
|
from diffusers import FlowMatchEulerDiscreteScheduler |
|
|
from diffusers.image_processor import VaeImageProcessor |
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
from videox_fun.models import ( |
|
|
AutoTokenizer, |
|
|
AutoencoderKL, |
|
|
Qwen3ForCausalLM, |
|
|
ZImageTransformer2DModel, |
|
|
) |
|
|
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
|
|
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
|
|
|
|
|
|
|
|
|
# Model configuration: YAML file loaded with OmegaConf in main() (resolved
# against repo_root when relative) and the checkpoint directory handed to the
# various from_pretrained() calls.
config_path_default = "config/z_image/z_image.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Default ONNX model and subgraph split configuration; both can be overridden
# on the command line (--onnx / --subgraph-config).
ORIGINAL_ONNX = os.path.join(repo_root, "onnx-models/z_image_transformer_body_only_simp_slim.onnx")
SUBGRAPH_CONFIG = os.path.join(repo_root, "pulsar2_configs/transformers_subgraph.json")

# Prompts used to generate calibration samples; every prompt is run through
# the full denoising loop.
PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
]

# Default output locations (--output-dir / --tar-list-file). Only the basename
# of TAR_LIST_FILE is reused when a custom output directory is given.
OUTPUT_BASE_DIR = os.path.join(repo_root, "onnx-calibration-subgraphs")
TAR_LIST_FILE = os.path.join(OUTPUT_BASE_DIR, "subgraph_calibration_paths.txt")

# Inference defaults. These are module globals that main() overwrites from the
# parsed command-line arguments before inference starts.
sample_size = [512, 512]
num_inference_steps = 9
seed = 42
sampler_name = "Flow"
vae_scale_factor = 8
# NOTE(review): vae_scale is kept in sync by main() but is not read anywhere in
# the visible part of this file.
vae_scale = vae_scale_factor * 2
max_sequence_length = 128
|
|
|
|
|
|
|
|
def _select_weight_dtype(device: torch.device) -> torch.dtype: |
|
|
if device.type == "cuda": |
|
|
if torch.cuda.is_bf16_supported(): |
|
|
return torch.bfloat16 |
|
|
return torch.float16 |
|
|
return torch.float32 |
|
|
|
|
|
|
|
|
# Sampler name -> scheduler class. The keys are the accepted values of the
# --sampler command-line option; main() instantiates the selected class via
# from_pretrained().
SCHEDULER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
|
|
|
|
|
|
|
|
def _resolve_config_path(path: str) -> Optional[str]: |
|
|
candidate = path if os.path.isabs(path) else os.path.join(repo_root, path) |
|
|
return candidate if os.path.exists(candidate) else None |
|
|
|
|
|
|
|
|
def _infer_module_device(module: torch.nn.Module) -> torch.device: |
|
|
param = next(module.parameters(), None) |
|
|
if param is not None: |
|
|
return param.device |
|
|
buffer = next(module.buffers(), None) |
|
|
if buffer is not None: |
|
|
return buffer.device |
|
|
return torch.device("cpu") |
|
|
|
|
|
|
|
|
def _encode_prompt( |
|
|
prompt: Union[str, List[str]], |
|
|
device: Optional[torch.device], |
|
|
tokenizer: AutoTokenizer, |
|
|
text_encoder: Qwen3ForCausalLM, |
|
|
max_sequence_length: int = 512, |
|
|
) -> List[torch.FloatTensor]: |
|
|
device = device or _infer_module_device(text_encoder) |
|
|
if isinstance(prompt, str): |
|
|
prompt = [prompt] |
|
|
|
|
|
for i, prompt_item in enumerate(prompt): |
|
|
messages = [{"role": "user", "content": prompt_item}] |
|
|
prompt_item = tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=True, |
|
|
enable_thinking=True, |
|
|
) |
|
|
prompt[i] = prompt_item |
|
|
|
|
|
text_inputs = tokenizer( |
|
|
prompt, |
|
|
padding="max_length", |
|
|
max_length=max_sequence_length, |
|
|
truncation=True, |
|
|
return_tensors="pt", |
|
|
) |
|
|
|
|
|
text_input_ids = text_inputs.input_ids.to(device) |
|
|
prompt_masks = text_inputs.attention_mask.to(device).bool() |
|
|
|
|
|
with torch.no_grad(): |
|
|
prompt_embeds = text_encoder( |
|
|
input_ids=text_input_ids, |
|
|
attention_mask=prompt_masks, |
|
|
output_hidden_states=True, |
|
|
).hidden_states[-2] |
|
|
|
|
|
embeddings_list = [] |
|
|
for i in range(len(prompt_embeds)): |
|
|
embeddings_list.append(prompt_embeds[i]) |
|
|
return embeddings_list |
|
|
|
|
|
|
|
|
def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device],
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    do_classifier_free_guidance: bool = False,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    max_sequence_length: int = 512,
):
    """Encode the positive prompt(s) and, optionally, the negative prompt(s).

    Args:
        prompt: A prompt string or list of prompt strings.
        device: Device forwarded to ``_encode_prompt``.
        tokenizer: Tokenizer forwarded to ``_encode_prompt``.
        text_encoder: Text encoder forwarded to ``_encode_prompt``.
        do_classifier_free_guidance: When true, negative prompts are encoded
            as well.
        negative_prompt: Negative prompt(s); when ``None`` an empty string is
            used for every positive prompt.
        max_sequence_length: Token length forwarded to ``_encode_prompt``.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``. The second element is an
        empty list when classifier-free guidance is disabled.
    """
    positive_prompts = [prompt] if isinstance(prompt, str) else prompt
    shared_kwargs = dict(
        device=device,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        max_sequence_length=max_sequence_length,
    )
    prompt_embeds = _encode_prompt(prompt=positive_prompts, **shared_kwargs)

    if not do_classifier_free_guidance:
        return prompt_embeds, []

    if negative_prompt is None:
        negative_prompts = ["" for _ in positive_prompts]
    elif isinstance(negative_prompt, str):
        negative_prompts = [negative_prompt]
    else:
        negative_prompts = negative_prompt

    negative_prompt_embeds = _encode_prompt(prompt=negative_prompts, **shared_kwargs)
    return prompt_embeds, negative_prompt_embeds
|
|
|
|
|
|
|
|
def _stack_prompt_embeddings(prompt_embeds_input): |
|
|
if isinstance(prompt_embeds_input, list): |
|
|
return torch.stack(prompt_embeds_input, dim=0) |
|
|
return prompt_embeds_input |
|
|
|
|
|
|
|
|
def prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator):
    """Draw the initial random latents for an image of ``height`` x ``width``.

    Each spatial size is floor-divided by ``vae_scale_factor * 2`` and then
    doubled, so the resulting latent height and width are always even.

    Returns:
        A tensor of shape ``(batch_size, num_channels_latents, h, w)`` produced
        by ``randn_tensor`` with the given generator, device and dtype.
    """
    downscale = vae_scale_factor * 2
    latent_height = 2 * (int(height) // downscale)
    latent_width = 2 * (int(width) // downscale)
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)
    return randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image sequence length to a shift value.

    The mapping is the straight line through ``(base_seq_len, base_shift)``
    and ``(max_seq_len, max_shift)``, evaluated at ``image_seq_len``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
|
|
|
|
|
|
def retrieve_timesteps(scheduler, num_inference_steps: Optional[int] = None, device=None, **kwargs):
    """Configure *scheduler* and return its timestep schedule.

    ``scheduler.set_timesteps`` is called with ``num_inference_steps``,
    ``device`` and any extra keyword arguments.

    Returns:
        ``(timesteps, count)`` where ``timesteps`` is ``scheduler.timesteps``
        after configuration and ``count`` is its length.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    schedule = scheduler.timesteps
    step_count = len(schedule)
    return schedule, step_count
|
|
|
|
|
|
|
|
def _to_float32_np(tensor: torch.Tensor) -> np.ndarray: |
|
|
return tensor.detach().to(dtype=torch.float32, device="cpu").numpy() |
|
|
|
|
|
|
|
|
def load_subgraph_config(config_path: str) -> List[dict]:
    """Load the list of subgraph definitions from a JSON config file.

    The subgraphs are read from ``config["compiler"]["sub_configs"]``.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The list of subgraph configuration dicts, or an empty list when the
        ``compiler`` section or its ``sub_configs`` entry is missing or null.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # `or` also covers explicit nulls, which .get(key, default) would pass through.
    compiler_section = config.get("compiler") or {}
    return compiler_section.get("sub_configs") or []
|
|
|
|
|
|
|
|
def add_intermediate_outputs_to_onnx(onnx_path: str, output_tensor_names: List[str], output_path: str):
    """Expose intermediate tensors of an ONNX model as additional graph outputs.

    The model is loaded without its external data. Every requested tensor that
    is not already a graph output is appended to ``graph.output``; its type and
    shape are copied from the graph's ``value_info`` when available, otherwise
    an untyped placeholder is used. Each name is added at most once, even if it
    appears several times in ``output_tensor_names``.

    Args:
        onnx_path: Path of the source ONNX model.
        output_tensor_names: Names of tensors that must be graph outputs.
        output_path: Where to write the modified model.

    Returns:
        ``output_path`` when at least one output was added and the modified
        model was saved, otherwise ``onnx_path`` (nothing is written).
    """
    import onnx
    from onnx import TensorProto

    print(f"加载ONNX模型: {onnx_path}")
    model = onnx.load(onnx_path, load_external_data=False)

    known_outputs = {out.name for out in model.graph.output}

    # Index value_info once; the first entry wins if a name is repeated.
    value_info_by_name = {}
    for vi in model.graph.value_info:
        value_info_by_name.setdefault(vi.name, vi)

    added_count = 0
    for tensor_name in output_tensor_names:
        if tensor_name in known_outputs:
            continue

        value_info = value_info_by_name.get(tensor_name)
        if value_info is None:
            # No shape/type recorded for this tensor: declare it untyped.
            value_info = onnx.helper.make_tensor_value_info(
                tensor_name, TensorProto.UNDEFINED, None
            )

        output_vi = onnx.ValueInfoProto()
        output_vi.CopyFrom(value_info)
        model.graph.output.append(output_vi)
        # Remember the name so a duplicate request does not add it twice.
        known_outputs.add(tensor_name)
        added_count += 1

    if added_count > 0:
        print(f"添加了 {added_count} 个中间输出")
        onnx.save(model, output_path, save_as_external_data=False)
        print(f"修改后的模型已保存到: {output_path}")
        return output_path

    print("所有需要的输出已存在,无需修改模型")
    return onnx_path
|
|
|
|
|
|
|
|
def create_onnx_session(onnx_path: str, providers: List[str]):
    """Create an ONNX Runtime inference session with graph optimizations disabled.

    Args:
        onnx_path: Path of the ONNX model to load.
        providers: Execution providers, in priority order.

    Returns:
        The created ``ort.InferenceSession``.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    return ort.InferenceSession(onnx_path, options, providers=providers)
|
|
|
|
|
|
|
|
def _new_auto_subgraph_entry() -> dict:
    """Return an empty collection entry for the trailing ``auto_00`` subgraph."""
    return {
        "start_tensors": ["/model/layers.29/Add_4_output_0", "/model/t_embedder/mlp/mlp.2/Gemm_output_0"],
        "end_tensors": ["sample"],
        "samples": [],
    }


def _gather_subgraph_inputs(subgraph_label, start_tensors, onnx_feed, onnx_output_dict) -> dict:
    """Look up a subgraph's input tensors among the model inputs and outputs."""
    inputs_dict = {}
    for tensor_name in start_tensors:
        if tensor_name in onnx_feed:
            # The subgraph consumes a model input directly.
            inputs_dict[tensor_name] = onnx_feed[tensor_name]
        elif tensor_name in onnx_output_dict:
            # The subgraph consumes an intermediate tensor exposed as an output.
            inputs_dict[tensor_name] = onnx_output_dict[tensor_name]
        else:
            print(f" 警告: 子图 {subgraph_label} 缺少输入张量: {tensor_name}")
    return inputs_dict


def _predict_noise_with_torch(transformer, latent_5d, timestep_model_input, prompt_embeds, device):
    """Run the PyTorch transformer for one step and return the negated float32 prediction."""
    timestep_on_device = timestep_model_input.to(device=device, dtype=transformer.dtype)
    # The transformer is called with a list of per-sample latents.
    latent_model_input_list = list(latent_5d.to(transformer.dtype).unbind(dim=0))

    def _forward():
        return transformer(
            latent_model_input_list,
            timestep_on_device,
            prompt_embeds,
            patch_size=2,
            f_patch_size=1,
        )[0]

    if device.type == "cuda":
        with torch.autocast(device_type="cuda", dtype=transformer.dtype):
            model_out_list = _forward()
    else:
        model_out_list = _forward()

    noise_pred = torch.stack([model_out.float() for model_out in model_out_list], dim=0)
    if noise_pred.dim() == 5:
        noise_pred = noise_pred.squeeze(2)
    return -noise_pred


def run_inference_and_collect(
    tokenizer,
    text_encoder,
    transformer,
    scheduler,
    device,
    weight_dtype,
    sub_configs: List[dict],
    onnx_session,
    output_base_dir: str,
    skip_existing: bool = False,
):
    """Run the denoising loop and record the input tensors of every subgraph.

    For every prompt in ``PROMPTS`` and every scheduler timestep the ONNX
    model is executed, and the tensors named in each subgraph's start tensors
    are stored as one calibration sample. When the ONNX session is
    unavailable (``None``) or an ONNX run fails, the step is computed with the
    PyTorch transformer instead; in that case only the model-level inputs
    (``latent_model_input``, ``timestep``, ``prompt_embeds``) of the first
    three subgraphs can be recorded.

    Args:
        tokenizer: Tokenizer used for prompt encoding.
        text_encoder: Text encoder used for prompt encoding.
        transformer: PyTorch transformer used as fallback.
        scheduler: Diffusion scheduler; its timesteps are reset per prompt.
        device: Torch device for latents and the fallback model.
        weight_dtype: Unused here; kept for interface compatibility.
        sub_configs: Subgraph definitions with ``start_tensor_names`` and
            ``end_tensor_names``.
        onnx_session: ONNX Runtime session exposing the intermediate tensors,
            or ``None``.
        output_base_dir: Directory checked for existing tar archives.
        skip_existing: Reuse subgraphs whose tar archive already exists.

    Returns:
        Dict mapping subgraph label (``cfg_XX`` / ``auto_00``) to either a
        fresh entry with ``start_tensors``, ``end_tensors`` and ``samples``,
        or the placeholder returned by ``load_existing_subgraph_data``.
    """
    existing_data = check_existing_data(output_base_dir) if skip_existing else {}

    subgraph_data = {}
    need_inference = False

    for idx, config in enumerate(sub_configs):
        subgraph_label = f"cfg_{idx:02d}"

        if skip_existing and existing_data.get(subgraph_label, False):
            loaded_data = load_existing_subgraph_data(subgraph_label, output_base_dir)
            if loaded_data:
                subgraph_data[subgraph_label] = loaded_data
                continue

        subgraph_data[subgraph_label] = {
            "start_tensors": config["start_tensor_names"],
            "end_tensors": config["end_tensor_names"],
            "samples": [],
        }
        need_inference = True

    auto_loaded = None
    if skip_existing and existing_data.get("auto_00", False):
        auto_loaded = load_existing_subgraph_data("auto_00", output_base_dir)
    if auto_loaded:
        subgraph_data["auto_00"] = auto_loaded
    else:
        subgraph_data["auto_00"] = _new_auto_subgraph_entry()
        need_inference = True

    if not need_inference:
        print("\n所有子图数据都已存在,跳过推理过程")
        return subgraph_data

    # Only freshly created entries collect samples. Placeholders loaded from an
    # existing tar have no "start_tensors" and must not be touched.
    pending = {
        label: entry
        for label, entry in subgraph_data.items()
        if not entry.get("loaded_from_existing", False)
    }

    height, width = sample_size
    num_channels_latents = 16

    output_names = []
    if onnx_session is not None:
        onnx_inputs = {inp.name: inp for inp in onnx_session.get_inputs()}
        output_names = [out.name for out in onnx_session.get_outputs()]
        print(f"ONNX模型输入: {list(onnx_inputs.keys())}")
        print(f"ONNX模型输出数量: {len(output_names)}")
    else:
        print("ONNX会话不可用,直接使用PyTorch推理")
    print(f"收集子图数量: {len(subgraph_data)} (包括 auto_00)")

    for prompt_idx, prompt in enumerate(PROMPTS):
        print(f"\n处理 Prompt {prompt_idx + 1}/{len(PROMPTS)}: {prompt[:60]}...")

        prompt_embeds, _ = encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=False,
            device=device,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            max_sequence_length=max_sequence_length,
        )

        # A different seed per prompt gives distinct initial latents.
        generator = torch.Generator(device=device).manual_seed(seed + prompt_idx)
        latents = prepare_latents(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            dtype=torch.float32,
            device=device,
            generator=generator,
        )

        prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)
        image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
        mu = calculate_shift(
            image_seq_len,
            scheduler.config.get("base_image_seq_len", 256),
            scheduler.config.get("max_image_seq_len", 4096),
            scheduler.config.get("base_shift", 0.5),
            scheduler.config.get("max_shift", 1.15),
        )
        scheduler.sigma_min = 0.0
        timesteps, _ = retrieve_timesteps(
            scheduler,
            num_inference_steps=num_inference_steps,
            device=device,
            mu=mu,
        )

        for step_idx, t in enumerate(timesteps):
            print(f" Step {step_idx + 1}/{len(timesteps)}")

            if latents.dim() == 5:
                print(" 警告: latents是5维,squeeze到4维")
                latents = latents.squeeze(2)

            timestep = t.expand(latents.shape[0])
            timestep_model_input = (1000 - timestep) / 1000

            # The ONNX model takes 5-D latents; a singleton axis is inserted at dim 2.
            latent_for_onnx = latents.unsqueeze(2)
            prompt_for_onnx = prompt_embeds_tensor

            onnx_feed = {
                "latent_model_input": _to_float32_np(latent_for_onnx),
                "timestep": _to_float32_np(timestep_model_input),
                "prompt_embeds": _to_float32_np(prompt_for_onnx),
            }

            if step_idx == 0:
                print(f" 调试: latents原始 shape = {latents.shape}")
                print(f" 调试: latent_for_onnx shape = {latent_for_onnx.shape}")
                print(f" 调试: timestep shape = {timestep_model_input.shape}")
                print(f" 调试: prompt shape = {prompt_for_onnx.shape}")

            sample_id = f"prompt{prompt_idx:03d}_step{step_idx:02d}"

            # Only the ONNX run itself is guarded, so a failure cannot leave
            # half-recorded samples behind before the fallback records its own.
            onnx_output_dict = None
            if onnx_session is not None:
                try:
                    onnx_results = onnx_session.run(None, onnx_feed)
                    onnx_output_dict = dict(zip(output_names, onnx_results))
                except Exception as e:
                    print(f" 错误: ONNX推理失败: {e}")
                    print(" 回退到PyTorch推理...")

            if onnx_output_dict is not None:
                for subgraph_label, entry in pending.items():
                    inputs_dict = _gather_subgraph_inputs(
                        subgraph_label, entry["start_tensors"], onnx_feed, onnx_output_dict
                    )
                    if inputs_dict:
                        entry["samples"].append({"id": sample_id, "inputs": inputs_dict})

                if "sample" not in onnx_output_dict:
                    print(" 警告: ONNX输出中没有'sample',跳过latents更新")
                    break

                noise_pred = torch.from_numpy(onnx_output_dict["sample"]).to(device=device, dtype=torch.float32)
                if noise_pred.dim() == 5:
                    noise_pred = noise_pred.squeeze(2)
            else:
                noise_pred = _predict_noise_with_torch(
                    transformer, latent_for_onnx, timestep_model_input, prompt_embeds, device
                )

                # Without ONNX outputs only model-level inputs are known, so at
                # most the first three subgraphs can be fed.
                for idx in range(min(3, len(sub_configs))):
                    subgraph_label = f"cfg_{idx:02d}"
                    entry = pending.get(subgraph_label)
                    if entry is None:
                        continue
                    inputs_dict = {
                        tensor_name: onnx_feed[tensor_name]
                        for tensor_name in sub_configs[idx]["start_tensor_names"]
                        if tensor_name in onnx_feed
                    }
                    if inputs_dict:
                        entry["samples"].append({"id": sample_id, "inputs": inputs_dict})

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            latents = latents.to(device=device, dtype=torch.float32)

            if latents.dim() == 5:
                latents = latents.squeeze(2)

    return subgraph_data
|
|
|
|
|
|
|
|
def check_existing_data(output_base_dir: str, num_subgraphs: int = 33) -> Dict[str, bool]:
    """Report which subgraph tar archives already exist in ``output_base_dir``.

    Args:
        output_base_dir: Directory that holds ``<label>.tar`` archives.
        num_subgraphs: Number of ``cfg_XX`` labels that are always reported
            (``True`` or ``False``). Archives with a higher index are reported
            as well when they are present on disk.

    Returns:
        Dict[str, bool]: Mapping from subgraph label to whether its tar file
        exists. Empty when ``output_base_dir`` does not exist.
    """
    existing: Dict[str, bool] = {}
    output_dir = Path(output_base_dir)
    if not output_dir.exists():
        return existing

    for i in range(num_subgraphs):
        subgraph_label = f"cfg_{i:02d}"
        existing[subgraph_label] = (output_dir / f"{subgraph_label}.tar").exists()

    # Configurations with more subgraphs than the default range would
    # otherwise never be detected as already collected.
    for tar_path in output_dir.glob("cfg_*.tar"):
        existing[tar_path.stem] = True

    existing["auto_00"] = (output_dir / "auto_00.tar").exists()

    return existing
|
|
|
|
|
|
|
|
def load_existing_subgraph_data(subgraph_label: str, output_base_dir: str) -> Optional[Dict]:
    """Build a placeholder entry for a subgraph whose tar archive already exists.

    The archive is only inspected: the number of members ending in ``.npy``
    is counted, the sample data itself is not read.

    Returns:
        Dict or None: Placeholder with ``loaded_from_existing``,
        ``sample_count``, ``tar_path`` (absolute) and an empty ``samples``
        list; ``None`` when the archive is missing or cannot be read.
    """
    archive_path = Path(output_base_dir) / f"{subgraph_label}.tar"
    if not archive_path.exists():
        return None

    try:
        with tarfile.open(archive_path, "r") as archive:
            sample_count = sum(name.endswith(".npy") for name in archive.getnames())

        print(f" 从已存在的tar加载: {subgraph_label} ({sample_count} 样本)")
        placeholder = {
            "loaded_from_existing": True,
            "sample_count": sample_count,
            "tar_path": str(archive_path.absolute()),
            "samples": [],
        }
        return placeholder
    except Exception as e:
        # Best effort: an unreadable archive is treated as "not available".
        print(f" 警告: 加载已存在的数据失败 ({subgraph_label}): {e}")
        return None
|
|
|
|
|
|
|
|
def save_subgraph_data(subgraph_data: Dict, output_base_dir: str, skip_existing: bool = False) -> List[str]:
    """Write the collected samples to disk and pack one tar archive per subgraph.

    Every sample is stored as ``<output_base_dir>/<label>/<sample id>.npy``
    (the whole inputs dict is passed to ``np.save``), then the subgraph
    directory is added to ``<output_base_dir>/<label>.tar``. Entries that were
    loaded from an existing archive are not rewritten as long as that archive
    is still present.

    Args:
        subgraph_data: Mapping from subgraph label to its collected entry.
        output_base_dir: Destination directory; created if necessary.
        skip_existing: Not used by this function.

    Returns:
        The tar archive paths, in the iteration order of ``subgraph_data``.
    """
    os.makedirs(output_base_dir, exist_ok=True)
    collected_tars = []

    for label, entry in subgraph_data.items():
        if entry.get("loaded_from_existing", False):
            previous_tar = entry.get("tar_path")
            if previous_tar and os.path.exists(previous_tar):
                collected_tars.append(previous_tar)
                print(f" 跳过已存在: {label} ({entry.get('sample_count', 0)} 样本)")
                continue

        print(f"\n保存子图数据: {label}")

        sample_dir = os.path.join(output_base_dir, label)
        os.makedirs(sample_dir, exist_ok=True)

        samples = entry["samples"]
        for sample in samples:
            np.save(os.path.join(sample_dir, f"{sample['id']}.npy"), sample["inputs"])

        tar_path = os.path.join(output_base_dir, f"{label}.tar")
        with tarfile.open(tar_path, "w") as archive:
            archive.add(sample_dir, arcname=label)

        collected_tars.append(os.path.abspath(tar_path))
        print(f" 已创建: {tar_path} (包含 {len(samples)} 个样本)")

    return collected_tars
|
|
|
|
|
|
|
|
def write_tar_list(tar_paths: List[str], output_file: str):
    """Write the tar archive paths to ``output_file``, one path per line."""
    with open(output_file, 'w') as listing:
        listing.writelines(tar_path + '\n' for tar_path in tar_paths)
    print(f"\n所有tar文件路径已写入: {output_file}")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: collect calibration inputs for all subgraphs.

    Steps:
        1. Parse the command line and store the inference settings in the
           module-level globals used by the helper functions.
        2. Load the subgraph configuration and gather every tensor name it
           references.
        3. Load tokenizer, text encoder, transformer and scheduler.
        4. If the ONNX model exists, write a copy that exposes the referenced
           tensors as outputs and open an ONNX Runtime session on it.
        5. Run inference, save the samples as tar archives and write the list
           of archive paths.

    Raises:
        ValueError: If the selected sampler name has no scheduler class.
    """
    global sample_size, num_inference_steps, seed, sampler_name, vae_scale_factor, vae_scale, max_sequence_length

    parser = argparse.ArgumentParser(description="收集子图ONNX模型的量化校准数据")
    parser.add_argument(
        "--skip-existing",
        action="store_true",
        help="跳过已经存在的tar文件,不重新生成数据"
    )
    parser.add_argument(
        "--onnx",
        default=ORIGINAL_ONNX,
        help="原始ONNX模型路径,默认使用项目内置模型"
    )
    parser.add_argument(
        "--subgraph-config",
        dest="subgraph_config",
        default=SUBGRAPH_CONFIG,
        help="子图配置json路径,默认使用项目内置配置"
    )
    parser.add_argument(
        "--output-dir",
        dest="output_base_dir",
        default=OUTPUT_BASE_DIR,
        help="输出目录,存放子图输入数据及tar文件"
    )
    parser.add_argument(
        "--tar-list-file",
        default=None,
        help="tar列表文件路径,未提供时默认写入输出目录下的subgraph_calibration_paths.txt"
    )
    parser.add_argument(
        "--sample-size",
        nargs=2,
        type=int,
        metavar=("H", "W"),
        default=sample_size,
        help="推理分辨率,格式: H W,默认 512 512",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=num_inference_steps,
        help="推理步数,默认 9",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=seed,
        help="随机种子,默认 42",
    )
    parser.add_argument(
        "--sampler",
        dest="sampler_name",
        choices=sorted(SCHEDULER_MAP.keys()),
        default=sampler_name,
        help="采样器名称,默认 Flow",
    )
    parser.add_argument(
        "--vae-scale-factor",
        type=int,
        default=vae_scale_factor,
        help="VAE 下采样因子,默认 8",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=max_sequence_length,
        help="最大序列长度,默认 128",
    )
    args = parser.parse_args()

    print("=" * 80)
    print("收集子图ONNX模型的量化校准数据")
    print("=" * 80)

    # The helper functions read these module globals, so publish the CLI values.
    sample_size = args.sample_size
    num_inference_steps = args.num_inference_steps
    seed = args.seed
    sampler_name = args.sampler_name
    vae_scale_factor = args.vae_scale_factor
    vae_scale = vae_scale_factor * 2
    max_sequence_length = args.max_seq_len

    if args.skip_existing:
        print("模式: 跳过已存在的数据")
    else:
        print("模式: 重新生成所有数据")

    original_onnx = os.path.expanduser(args.onnx)
    subgraph_config_path = os.path.expanduser(args.subgraph_config)
    output_base_dir = os.path.expanduser(args.output_base_dir)
    if args.tar_list_file:
        tar_list_file = os.path.expanduser(args.tar_list_file)
    else:
        tar_list_file = os.path.join(output_base_dir, os.path.basename(TAR_LIST_FILE))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    weight_dtype = _select_weight_dtype(device)

    print(f"\n设备: {device}")
    print(f"数据类型: {weight_dtype}")

    print(f"\n加载子图配置: {subgraph_config_path}")
    sub_configs = load_subgraph_config(subgraph_config_path)
    print(f"找到 {len(sub_configs)} 个子图配置")

    all_tensor_names = set()
    for sub_config in sub_configs:
        all_tensor_names.update(sub_config["start_tensor_names"])
        all_tensor_names.update(sub_config["end_tensor_names"])
    print(f"需要提取 {len(all_tensor_names)} 个张量")

    # Derive the output name from the extension only. str.replace() would also
    # rewrite ".onnx" inside directory names, and would leave the path
    # unchanged (overwriting the original model) for any other extension.
    onnx_root, onnx_ext = os.path.splitext(original_onnx)
    modified_onnx_path = f"{onnx_root}_with_outputs{onnx_ext}"

    print(f"\n加载PyTorch模型: {model_name}")
    config_path = _resolve_config_path(config_path_default)
    model_config = OmegaConf.load(config_path) if config_path else None
    extra_kwargs = {}
    if model_config is not None and hasattr(model_config, "transformer_additional_kwargs"):
        extra_kwargs = OmegaConf.to_container(model_config.transformer_additional_kwargs, resolve=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        model_name,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
    ).to(device=device, dtype=weight_dtype)
    text_encoder.eval()

    transformer = ZImageTransformer2DModel.from_pretrained(
        model_name,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=weight_dtype,
        **({"transformer_additional_kwargs": extra_kwargs} if extra_kwargs else {}),
    ).to(device=device, dtype=weight_dtype)
    transformer.eval()

    scheduler_cls = SCHEDULER_MAP.get(sampler_name)
    if scheduler_cls is None:
        raise ValueError(f"不支持的采样器: {sampler_name}")
    scheduler = scheduler_cls.from_pretrained(model_name, subfolder="scheduler")

    onnx_session = None
    if os.path.exists(original_onnx):
        print(f"\n找到原始ONNX模型: {original_onnx}")
        try:
            print("准备添加中间输出到ONNX模型...")
            # Sorted so the added outputs have a reproducible order across runs.
            modified_path = add_intermediate_outputs_to_onnx(
                original_onnx,
                sorted(all_tensor_names),
                modified_onnx_path
            )

            print("创建ONNX推理会话...")
            if device.type == 'cuda':
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']
            onnx_session = create_onnx_session(modified_path, providers)
            print(f"ONNX会话创建成功,使用provider: {onnx_session.get_providers()}")
        except Exception as e:
            # Best effort: collection continues without intermediate tensors.
            print(f"无法加载ONNX模型: {e}")
            print("将使用PyTorch模型进行推理")

    if onnx_session is None:
        print("\n警告: 未能创建ONNX会话,将只收集基本输入数据")
        print("这意味着无法获取中间层的数据,只能收集前几个子图的输入")
        print("如需完整数据,请确保ONNX模型可用")

    print("\n开始推理并收集数据...")
    print(f"Prompts数量: {len(PROMPTS)}")
    print(f"推理步数: {num_inference_steps}")

    subgraph_data = run_inference_and_collect(
        tokenizer,
        text_encoder,
        transformer,
        scheduler,
        device,
        weight_dtype,
        sub_configs,
        onnx_session,
        output_base_dir,
        skip_existing=args.skip_existing,
    )

    print(f"\n保存数据到: {output_base_dir}")
    tar_paths = save_subgraph_data(subgraph_data, output_base_dir, skip_existing=args.skip_existing)

    # A custom --tar-list-file may point into a directory that does not exist
    # yet; create it so the run does not fail after all the work is done.
    os.makedirs(os.path.dirname(os.path.abspath(tar_list_file)), exist_ok=True)
    write_tar_list(tar_paths, tar_list_file)

    print("\n" + "=" * 80)
    print("完成!")
    print(f"总共创建了 {len(tar_paths)} 个tar文件")
    print(f"tar文件列表: {tar_list_file}")
    print("=" * 80)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example invocation:
    #
    #   python examples/z_image_fun/collect_subgraph_inputs.py \
    #       --onnx /path/model.onnx \
    #       --subgraph-config /path/subgraph.json \
    #       --output-dir /data/out \
    #       --tar-list-file /data/out/subgraph_calibration_paths.txt \
    #       --sample-size 640 640 \
    #       --max-seq-len 256
    #
    # Gradients are never needed for data collection, so disable them globally.
    torch.set_grad_enabled(False)
    main()
|
|
|