"""Run MIDI-3D scene generation with the experimental sketch-fusion adapter.

The script loads the pretrained MIDI pipeline, swaps in a transformer built
with sketch-attention settings, splits an RGB image into per-instance inputs
using a segmentation map, runs the pipeline, and exports the meshes extracted
with marching cubes as a single ``output.glb`` scene.
"""

import argparse
import os
import sys
from collections import defaultdict
from glob import glob
from pathlib import Path
from typing import Any, List, Union

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image, ImageOps
from skimage import measure

current_dir = Path(__file__).parent
# Project root, i.e. the parent directory of both ``midi`` and ``scripts``.
root_dir = current_dir.parent
# Make the project root importable so the ``midi`` package resolves below.
sys.path.append(str(root_dir))

from midi.pipelines.pipeline_midi import MIDIPipeline  # noqa: E402
from midi.utils.smoothing import smooth_gpu  # noqa: E402
from midi.sketch.fusion_adapter import FusionAdapterConfig, SketchFusionAdapter  # noqa: E402
from midi.sketch.sketch_tower import SketchVisionTowerConfig, SketchVisionTower  # noqa: E402
from midi.models.transformers.triposg_transformer import TripoSGDiTModel  # noqa: E402
from midi.sketch.sketch_utils import prepare_sketch_images  # noqa: E402


def _record_tensors(device_stats, module_name, named_tensors, label, verbose):
    """Append ``(qualified_name, shape)`` per tensor to ``device_stats`` by device type."""
    for tensor_name, tensor in named_tensors:
        device_type = tensor.device.type
        device_stats[device_type].append((f"{module_name}.{tensor_name}", tensor.shape))
        if device_type == "meta" and verbose:
            print(
                f"⚠️ Meta {label} found: {module_name}.{tensor_name} | "
                f"Shape: {tuple(tensor.shape)}"
            )


def check_meta_tensors(model, verbose: bool = True) -> dict:
    """Group every parameter and buffer of ``model`` by the device type it lives on.

    Args:
        model: Object exposing ``named_modules()`` (a PyTorch module).
        verbose: When true, print each meta-device tensor as it is found and a
            per-device summary at the end.

    Returns:
        A mapping from device type (``"cpu"``, ``"cuda"``, ``"meta"``, ...) to
        a list of ``(qualified_name, shape)`` tuples.
    """
    device_stats = defaultdict(list)

    # Walk each module once and inspect only its own tensors (recurse=False),
    # so nothing is counted twice.
    for name, module in model.named_modules():
        _record_tensors(
            device_stats, name, module.named_parameters(recurse=False), "tensor", verbose
        )
        _record_tensors(
            device_stats, name, module.named_buffers(recurse=False), "buffer", verbose
        )

    if verbose:
        print("\n📊 Device Distribution Summary:")
        for device, tensors in device_stats.items():
            print(f"• {device.upper()} device: {len(tensors)} tensors")
            # Show only the first five meta tensors to keep the log short.
            if device == "meta" and tensors:
                print(" Top Meta Tensors:")
                for tensor_name, shape in tensors[:5]:
                    print(f" - {tensor_name} | Shape: {shape}")

    return device_stats


def check_pipeline_meta_tensors(pipe):
    """Report how many tensors of each pipeline component sit on the meta device.

    Only attributes that exist on ``pipe`` and are ``nn.Module`` instances are
    inspected. ``transformer`` and ``sketch_fusion_adapter`` are included
    because they are the modules this script passes to the pipeline; without
    them the report would skip the main model.

    Returns:
        ``(component_stats, total_meta, all_meta_tensors)`` where
        ``component_stats`` maps component name to the result of
        ``check_meta_tensors``, ``total_meta`` is the number of meta tensors
        found, and ``all_meta_tensors`` lists them as ``(name, shape)``.
    """
    component_stats = {}
    all_meta_tensors = []
    total_meta = 0

    components = [
        "unet",
        "transformer",
        "vae",
        "text_encoder",
        "scheduler",
        "adapter_layers",
        "sketch_fusion_adapter",
    ]
    for comp_name in components:
        comp = getattr(pipe, comp_name, None)
        if not isinstance(comp, nn.Module):
            continue
        comp_stats = check_meta_tensors(comp, verbose=False)
        component_stats[comp_name] = comp_stats
        comp_meta = comp_stats.get("meta", [])
        total_meta += len(comp_meta)
        all_meta_tensors.extend(comp_meta)

    print("\n" + "=" * 50)
    print("MIDI Pipeline Meta Tensor Report")
    print("=" * 50)
    for comp, stats in component_stats.items():
        meta_count = len(stats.get("meta", []))
        print(f"{comp.upper()}: {meta_count} meta tensors")
    print("\n" + "=" * 50)
    print(f"TOTAL META TENSORS: {total_meta}")
    if total_meta > 0:
        print("🔴 META TENSORS FOUND!")
    else:
        print("✅ No tensors found on meta device!")
    print("=" * 50)

    return component_stats, total_meta, all_meta_tensors


def materialize_meta_tensors(module: Any, device: str):
    """Replace meta-device parameters and buffers with real tensors on ``device``.

    The replacement tensors are allocated with ``torch.empty_like`` and are
    therefore uninitialized; real weights must be loaded afterwards.

    Args:
        module: An ``nn.Module``, or a container object (such as a pipeline)
            that holds ``nn.Module`` instances as instance attributes.
        device: Target device for the newly allocated tensors.

    Returns:
        The same ``module`` object, modified in place.
    """
    if not isinstance(module, nn.Module):
        # A pipeline is not itself a module: process each module it holds.
        for attribute in list(getattr(module, "__dict__", {}).values()):
            if isinstance(attribute, nn.Module):
                materialize_meta_tensors(attribute, device)
        return module

    # recurse=False keeps the names free of dots, which setattr on nn.Module
    # requires; children are handled by the explicit recursion below.
    for name, param in list(module.named_parameters(recurse=False)):
        if param.device.type == "meta":
            new_param = nn.Parameter(
                torch.empty_like(param, device=device),
                requires_grad=param.requires_grad,
            )
            setattr(module, name, new_param)

    for name, buf in list(module.named_buffers(recurse=False)):
        if buf.device.type == "meta":
            setattr(module, name, torch.empty_like(buf, device=device))

    for child in module.children():
        materialize_meta_tensors(child, device)

    return module


def prepare_pipeline(device, dtype):
    """Download the MIDI-3D weights and build the sketch-enabled pipeline.

    Args:
        device: Device the pipeline is moved to.
        dtype: Currently unused.
            NOTE(review): the pre-sketch version called
            ``.to(device, dtype)``; confirm whether dropping the dtype cast
            was intentional before re-enabling it.

    Returns:
        The pipeline on ``device`` with the custom adapter initialized.

    Raises:
        RuntimeError: If meta tensors remain after materialization.
    """
    local_dir = "pretrained_weights/MIDI-3D"
    snapshot_download(repo_id="VAST-AI/MIDI-3D", local_dir=local_dir)

    sketch_fusion_adapter_config = FusionAdapterConfig(
        num_dit_layers=21,
        num_vision_layers=12,
        dim_dit_latent=2048,
        dim_vision_latent=1024,
        patch_size=14,
        fusion_mode="specific",
        dit_layer_seqs=[0, 2, 4, 5, 6, 8, 9, 10, 12, 14, 16, 18, 20],
        vision_layer_seqs=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        enable_projector=False,
        dim_projected_latent=768,
    )
    sketch_vision_tower_config = SketchVisionTowerConfig(
        select_feature_type="patch",
        arbitrary_input_size=False,
        input_size=(224, 224),
        select_layer="all",
        dim_vision_latent=1024,
    )

    transformer = TripoSGDiTModel.from_pretrained(
        os.path.join(local_dir, "./transformer"),
        strict=False,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=False,
        dit_fusion_layer_seq=[1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 19, 21],
        sketch_attention_dim=1024,
    )
    sketch_fusion_adapter = SketchFusionAdapter(
        sketch_fusion_adapter_config, sketch_vision_tower_config
    )

    pipe: MIDIPipeline = MIDIPipeline.from_pretrained(
        local_dir,
        transformer=transformer,
        sketch_fusion_adapter=sketch_fusion_adapter,
        strict=False,
    )
    # Re-assign explicitly so the pipeline uses the transformer built above.
    pipe.transformer = transformer
    pipe = pipe.to(device)

    _, total_meta, all_meta_tensors = check_pipeline_meta_tensors(pipe)

    if total_meta > 0:
        print(f"\nFound {total_meta} tensors on meta device:")
        for name, shape in all_meta_tensors[:5]:  # show only the first five
            print(f" - {name} | Shape: {shape}")
        if total_meta > 5:
            print(f" ... and {total_meta - 5} more.")

        print("\nMaterializing meta tensors...")
        pipe = materialize_meta_tensors(pipe, device=device)

        # Verify that nothing was left on the meta device.
        _, post_meta, _ = check_pipeline_meta_tensors(pipe)
        if post_meta > 0:
            raise RuntimeError("Failed to materialize all meta tensors!")
    else:
        # Nothing to materialize; just make sure the pipeline is on the device.
        pipe = pipe.to(device)

    pipe.init_custom_adapter(
        # TODO: SketchFusionAdapter is currently plain cross attention, so no
        # further configuration is needed here yet.
        set_self_attn_module_names=[
            "blocks.8",
            "blocks.9",
            "blocks.10",
            "blocks.11",
            "blocks.12",
        ],
        set_sketch_attn_module_names=[
            "blocks.0",
            "blocks.1",
            "blocks.2",
            "blocks.3",
            "blocks.4",
            "blocks.5",
            "blocks.6",
            "blocks.7",
            "blocks.8",
            "blocks.9",
            "blocks.10",
            "blocks.11",
            "blocks.12",
            "blocks.13",
            "blocks.14",
            "blocks.15",
            "blocks.16",
            "blocks.17",
            "blocks.18",
            "blocks.19",
            "blocks.20",
            "blocks.21",
        ],
    )
    return pipe


def preprocess_image(rgb_image, seg_image):
    """Pad an RGB image and its segmentation map so the instances fit a square.

    Both inputs may be file paths or PIL images. The RGB image is converted to
    ``"RGB"`` and the segmentation map to ``"L"``. If the segmentation map has
    no non-zero pixel, the converted images are returned unchanged.

    Returns:
        ``(rgb_new, seg_new)``: the padded images, white-filled for RGB and
        zero-filled for the segmentation map.
    """
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    width, height = rgb_image.size

    seg_np = np.array(seg_image)
    rows, cols = np.where(seg_np > 0)
    if rows.size == 0 or cols.size == 0:
        return rgb_image, seg_image

    # Bounding box of all instances combined.
    min_row, max_row = int(rows.min()), int(rows.max())
    min_col, max_col = int(cols.min()), int(cols.max())

    # Side of the centred square that contains the bounding box. Rows index
    # the vertical axis, so they are measured against the image height, and
    # columns against the width.
    L = max(
        max(abs(max_row - height // 2), abs(min_row - height // 2)) * 2,
        max(abs(max_col - width // 2), abs(min_col - width // 2)) * 2,
    )

    # Grow a dimension when the content would fill more than 80% of it.
    if L > width * 0.8:
        width = int(L / 4 * 5)
    if L > height * 0.8:
        height = int(L / 4 * 5)

    rgb_new = Image.new("RGB", (width, height), (255, 255, 255))
    seg_new = Image.new("L", (width, height), 0)
    x_offset = (width - rgb_image.size[0]) // 2
    y_offset = (height - rgb_image.size[1]) // 2
    rgb_new.paste(rgb_image, (x_offset, y_offset))
    seg_new.paste(seg_image, (x_offset, y_offset))

    # Pad on the right and bottom to make the result square.
    max_dim = max(width, height)
    border = (0, 0, max_dim - width, max_dim - height)
    rgb_new = ImageOps.expand(rgb_new, border=border, fill="white")
    seg_new = ImageOps.expand(seg_new, border=border, fill=0)

    return rgb_new, seg_new


def split_rgb_mask(rgb_image, seg_image):
    """Split a scene into per-instance RGB images and binary masks.

    Every distinct non-zero value in the segmentation map is treated as one
    instance, processed in ascending label order.

    Returns:
        ``(instance_rgbs, instance_masks, scene_rgbs)``: per instance, the
        scene with everything outside the instance painted white, a 0/255
        mask, and the full scene image (the same object repeated).
    """
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    rgb_array = np.array(rgb_image)
    seg_array = np.array(seg_image)

    # np.unique returns sorted values, so labels are already in ascending order.
    label_ids = np.unique(seg_array)
    label_ids = label_ids[label_ids > 0]

    instance_rgbs, instance_masks, scene_rgbs = [], [], []

    for segment_id in label_ids:
        selected = seg_array == segment_id

        mask = np.zeros_like(seg_array, dtype=np.uint8)
        mask[selected] = 255

        # White background with only this instance's pixels copied in.
        segment_rgb = np.full_like(rgb_array, 255)
        segment_rgb[selected] = rgb_array[selected]

        instance_rgbs.append(Image.fromarray(segment_rgb))
        instance_masks.append(Image.fromarray(mask))
        scene_rgbs.append(rgb_image)

    return instance_rgbs, instance_masks, scene_rgbs


@torch.no_grad()
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    sketch_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> trimesh.Scene:
    """Generate one mesh per segmented instance and return them as a scene.

    Args:
        pipe: Pipeline returned by ``prepare_pipeline``.
        rgb_image: Scene image, as a path or PIL image.
        seg_image: Segmentation map, as a path or PIL image.
        sketch_image: Sketch input forwarded to ``prepare_sketch_images``.
        seed: Seed for the generator; ``-1`` leaves generation unseeded.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        do_image_padding: Run ``preprocess_image`` on the inputs first.

    Returns:
        A ``trimesh.Scene`` containing one mesh per instance.

    Side effects:
        Writes ``test3_zoom_<idx>.png`` debug images into the current working
        directory.
    """
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)

    sketch_image_list = prepare_sketch_images(
        sketch_image,
        5,
        1,
        seg_images=[instance_masks],
        mode="zoom",
        cfg=False,
    )
    # Debug dump of the prepared sketch crops.
    for idx, sketch in enumerate(sketch_image_list[0]):
        sketch.save(f"test3_zoom_{idx}.png", format="PNG")

    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=pipe.device).manual_seed(seed)

    num_instances = len(instance_rgbs)
    outputs = pipe(
        "TEST1",
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        sketch_image=sketch_image_list,
        attention_kwargs={"num_instances": num_instances},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
        **pipe_kwargs,
    )

    # Marching cubes: one mesh per pipeline output.
    trimeshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
        grid_logits = logits_.view(grid_size)
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, _, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
        # Map grid indices into the bounding box reported by the pipeline.
        vertices = vertices / grid_size * bbox_size + bbox_min

        mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        trimeshes.append(mesh)

    # Compose the output meshes into a single scene.
    return trimesh.Scene(trimeshes)


if __name__ == "__main__":
    device = "cuda"
    dtype = torch.bfloat16

    parser = argparse.ArgumentParser()
    parser.add_argument("--rgb", type=str, required=True)
    parser.add_argument("--seg", type=str, required=True)
    parser.add_argument("--sketch", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    parser.add_argument("--do-image-padding", action="store_true")
    parser.add_argument("--output-dir", type=str, default="./")
    args = parser.parse_args()

    sketch_image = Image.open(args.sketch)
    sketch_array = np.array(sketch_image)
    # Diagnostic for three-channel sketches: report whether all channels match.
    if sketch_array.ndim == 3 and sketch_array.shape[2] == 3:
        channels_identical = np.array_equal(
            sketch_array[:, :, 0], sketch_array[:, :, 1]
        ) and np.array_equal(sketch_array[:, :, 1], sketch_array[:, :, 2])
        print(f"三个通道是否完全相同: {channels_identical}")
    sketch_image_list = [sketch_image]

    pipe = prepare_pipeline(device, dtype)
    scene = run_midi(
        pipe,
        rgb_image=args.rgb,
        seg_image=args.seg,
        sketch_image=sketch_image_list,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        do_image_padding=args.do_image_padding,
    )

    # Create the output directory up front so the export cannot fail on it.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    scene.export(os.path.join(args.output_dir, "output.glb"))