| import argparse |
| import os |
| from collections import defaultdict |
| from glob import glob |
| from typing import Any, List, Union |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| import trimesh |
| from huggingface_hub import snapshot_download |
| from PIL import Image, ImageOps |
| from skimage import measure |
| import sys |
| from pathlib import Path |
| current_dir = Path(__file__).parent |
| |
| root_dir = current_dir.parent |
| |
| sys.path.append(str(root_dir)) |
|
|
| from midi.pipelines.pipeline_midi import MIDIPipeline |
| from midi.utils.smoothing import smooth_gpu |
| from midi.sketch.fusion_adapter import FusionAdapterConfig,SketchFusionAdapter |
| from midi.sketch.sketch_tower import SketchVisionTowerConfig,SketchVisionTower |
| from midi.models.transformers.triposg_transformer import TripoSGDiTModel |
| import torch.nn as nn |
|
|
| from midi.sketch.sketch_utils import prepare_sketch_images |
|
|
def check_meta_tensors(model, verbose: bool = True) -> dict:
    """Inventory the devices of every parameter and buffer in *model*.

    Walks all submodules and groups each tensor's qualified name and shape
    by device type (e.g. "cpu", "cuda", "meta"). Tensors on the "meta"
    device are flagged because they have shape/dtype but no real storage.

    Args:
        model: PyTorch module to inspect.
        verbose: when True, print each meta tensor as it is found and a
            per-device summary at the end.

    Returns:
        Mapping of device-type string -> list of (qualified_name, shape).
    """
    stats = defaultdict(list)

    for mod_name, mod in model.named_modules():
        # recurse=False: each submodule reports only its own tensors, so
        # nothing is double-counted while walking named_modules().
        for kind, pairs in (
            ("tensor", mod.named_parameters(recurse=False)),
            ("buffer", mod.named_buffers(recurse=False)),
        ):
            for t_name, t in pairs:
                dev = t.device.type
                stats[dev].append((f"{mod_name}.{t_name}", t.shape))
                if dev == "meta" and verbose:
                    print(f"⚠️ Meta {kind} found: {mod_name}.{t_name} | Shape: {tuple(t.shape)}")

    if verbose:
        print("\n📊 Device Distribution Summary:")
        for dev, entries in stats.items():
            print(f"• {dev.upper()} device: {len(entries)} tensors")

            # Show a small sample of the problematic tensors only.
            if dev == "meta" and entries:
                print(" Top Meta Tensors:")
                for entry_name, entry_shape in entries[:5]:
                    print(f" - {entry_name} | Shape: {entry_shape}")

    return stats
|
|
|
|
def check_pipeline_meta_tensors(pipe):
    """Scan the well-known pipeline components for meta-device tensors.

    Returns:
        Tuple of (per-component stats dict, total meta tensor count,
        flat list of all meta tensors as (name, shape) pairs).
    """
    per_component = {}
    meta_entries = []

    # Only these attributes are inspected; missing or non-module
    # attributes are skipped silently.
    for attr in ("unet", "vae", "text_encoder", "scheduler", "adapter_layers"):
        candidate = getattr(pipe, attr, None)
        if not isinstance(candidate, nn.Module):
            continue
        stats = check_meta_tensors(candidate, verbose=False)
        per_component[attr] = stats
        meta_entries.extend(stats.get("meta", []))

    total = len(meta_entries)

    separator = "=" * 50
    print("\n" + separator)
    print("MIDI Pipeline Meta Tensor Report")
    print(separator)

    for comp_name, stats in per_component.items():
        print(f"{comp_name.upper()}: {len(stats.get('meta', []))} meta tensors")

    print("\n" + separator)
    print(f"TOTAL META TENSORS: {total}")

    if total > 0:
        print("🔴 META TENSORS FOUND!")
    else:
        print("✅ No tensors found on meta device!")
    print(separator)

    return per_component, total, meta_entries
|
|
|
|
def materialize_meta_tensors(module: nn.Module, device: str):
    """Recursively replace meta-device tensors with allocated ones on *device*.

    Tensors on the "meta" device carry shape/dtype but no storage.  This
    allocates real — uninitialized! — storage for each via
    ``torch.empty_like``; the values are garbage until actual weights are
    loaded over them.

    Args:
        module: root module, fixed up in place.
        device: target device string, e.g. "cpu" or "cuda".

    Returns:
        The same ``module``, for call-chaining.
    """
    # recurse=False is essential here: with the default recurse=True,
    # named_parameters() yields dotted names such as "layer.weight", and
    # setattr with a Parameter routes through register_parameter(), which
    # rejects names containing "." (KeyError).  Children are handled by
    # the explicit recursion below, which also avoids double-visiting.
    for name, param in module.named_parameters(recurse=False):
        if param.device.type == "meta":
            setattr(
                module,
                name,
                nn.Parameter(
                    torch.empty_like(param, device=device),
                    requires_grad=param.requires_grad,
                ),
            )

    for name, buf in module.named_buffers(recurse=False):
        if buf.device.type == "meta":
            # Assign through _buffers so the tensor stays registered as a
            # buffer rather than becoming a plain attribute.
            module._buffers[name] = torch.empty_like(buf, device=device)

    for child in module.children():
        materialize_meta_tensors(child, device)

    return module
def prepare_pipeline(device, dtype):
    """Download MIDI-3D weights and assemble a ready-to-run pipeline.

    Args:
        device: target device for the pipeline (e.g. "cuda").
        dtype: accepted for interface compatibility; not consumed here.

    Returns:
        A fully initialized MIDIPipeline placed on *device*.
    """
    weights_dir = "pretrained_weights/MIDI-3D"
    snapshot_download(repo_id="VAST-AI/MIDI-3D", local_dir=weights_dir)

    fusion_cfg = FusionAdapterConfig(
        num_dit_layers=21,
        num_vision_layers=12,
        dim_dit_latent=2048,
        dim_vision_latent=1024,
        patch_size=14,
        fusion_mode="specific",
        dit_layer_seqs=[0, 2, 4, 5, 6, 8, 9, 10, 12, 14, 16, 18, 20],
        vision_layer_seqs=list(range(13)),
        enable_projector=False,
        dim_projected_latent=768,
    )
    tower_cfg = SketchVisionTowerConfig(
        select_feature_type="patch",
        arbitrary_input_size=False,
        input_size=(224, 224),
        select_layer="all",
        dim_vision_latent=1024,
    )

    dit = TripoSGDiTModel.from_pretrained(
        os.path.join(weights_dir, "./transformer"),
        strict=False,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=False,
        dit_fusion_layer_seq=[1, 3, 5, 6, 7, 9, 10, 11, 13, 15, 17, 19, 21],
        sketch_attention_dim=1024,
    )
    adapter = SketchFusionAdapter(fusion_cfg, tower_cfg)

    pipe: MIDIPipeline = MIDIPipeline.from_pretrained(
        weights_dir,
        transformer=dit,
        sketch_fusion_adapter=adapter,
        strict=False,
    )
    pipe.transformer = dit
    pipe = pipe.to(device)

    # Loading with ignore_mismatched_sizes / partial checkpoints can leave
    # tensors stranded on the meta device; detect and give them real
    # storage before running inference.
    _, total_meta, meta_list = check_pipeline_meta_tensors(pipe)

    if total_meta > 0:
        print(f"\nFound {total_meta} tensors on meta device:")
        for name, shape in meta_list[:5]:
            print(f" - {name} | Shape: {shape}")
        if total_meta > 5:
            print(f" ... and {total_meta - 5} more.")

        print("\nMaterializing meta tensors...")
        pipe = materialize_meta_tensors(pipe, device=device)

        _, post_meta, _ = check_pipeline_meta_tensors(pipe)
        if post_meta > 0:
            raise RuntimeError("Failed to materialize all meta tensors!")
        # Re-run .to() so the newly materialized tensors land on *device*.
        pipe = pipe.to(device)

    # Self-attention adapters go on blocks 8-12; sketch cross-attention
    # adapters on every block 0-21.
    pipe.init_custom_adapter(
        set_self_attn_module_names=[f"blocks.{i}" for i in range(8, 13)],
        set_sketch_attn_module_names=[f"blocks.{i}" for i in range(22)],
    )
    return pipe
|
|
|
|
def preprocess_image(rgb_image, seg_image):
    """Pad an RGB/segmentation pair so the masked content is centered
    inside a square canvas with a margin.

    Args:
        rgb_image: PIL image or path to the scene RGB.
        seg_image: PIL image or path to the segmentation mask
            (non-zero pixels mark content).

    Returns:
        Tuple of (rgb, seg) PIL images, both the same square size; the
        inputs are returned unchanged when the mask is empty.
    """
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    width, height = rgb_image.size

    seg_np = np.array(seg_image)
    rows, cols = np.where(seg_np > 0)
    if rows.size == 0 or cols.size == 0:
        # Empty mask: nothing to center on.
        return rgb_image, seg_image

    min_row, max_row = min(rows), max(rows)
    min_col, max_col = min(cols), max(cols)
    # Largest distance of the mask bbox from the image center, doubled to
    # get a symmetric extent.  BUGFIX: rows index the vertical (height)
    # axis and cols the horizontal (width) axis; the previous code
    # compared rows against width // 2 and cols against height // 2,
    # mis-sizing the padding for non-square inputs.
    L = max(
        max(abs(max_row - height // 2), abs(min_row - height // 2)) * 2,
        max(abs(max_col - width // 2), abs(min_col - width // 2)) * 2,
    )

    # Grow a dimension by 25% over the content extent whenever the content
    # spans more than 80% of it, so the object keeps a border.
    if L > width * 0.8:
        width = int(L / 4 * 5)
    if L > height * 0.8:
        height = int(L / 4 * 5)
    rgb_new = Image.new("RGB", (width, height), (255, 255, 255))
    seg_new = Image.new("L", (width, height), 0)
    x_offset = (width - rgb_image.size[0]) // 2
    y_offset = (height - rgb_image.size[1]) // 2
    rgb_new.paste(rgb_image, (x_offset, y_offset))
    seg_new.paste(seg_image, (x_offset, y_offset))

    # Pad right/bottom to a square so later resizing keeps aspect ratio.
    max_dim = max(width, height)
    rgb_new = ImageOps.expand(
        rgb_new, border=(0, 0, max_dim - width, max_dim - height), fill="white"
    )
    seg_new = ImageOps.expand(
        seg_new, border=(0, 0, max_dim - width, max_dim - height), fill=0
    )

    return rgb_new, seg_new
|
|
|
|
def split_rgb_mask(rgb_image, seg_image):
    """Split a scene into per-instance RGB crops using a label mask.

    Every non-zero gray level in *seg_image* is treated as one instance:
    it yields an RGB image keeping only that instance's pixels (white
    elsewhere) plus the matching binary mask.  The full scene image is
    repeated once per instance so the three result lists stay aligned.

    Returns:
        (instance_rgbs, instance_masks, scene_rgbs) — parallel lists of
        PIL images, ordered by ascending instance label.
    """
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    rgb_arr = np.array(rgb_image)
    seg_arr = np.array(seg_image)

    labels = np.unique(seg_arr)
    labels = labels[labels > 0]

    instance_rgbs, instance_masks, scene_rgbs = [], [], []

    for label in sorted(labels):
        selected = seg_arr == label
        mask_arr = np.where(selected, 255, 0).astype(np.uint8)

        # Start from an all-white frame and copy in this instance's pixels.
        rgb_only = np.full_like(rgb_arr, 255)
        rgb_only[selected] = rgb_arr[selected]

        instance_rgbs.append(Image.fromarray(rgb_only))
        instance_masks.append(Image.fromarray(mask_arr))
        scene_rgbs.append(rgb_image)

    return instance_rgbs, instance_masks, scene_rgbs
|
|
|
|
@torch.no_grad()
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    sketch_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> trimesh.Scene:
    """Run the MIDI pipeline on one scene and mesh every instance.

    Args:
        pipe: a prepared MIDIPipeline (see prepare_pipeline).
        rgb_image: scene RGB image, or a path to one.
        seg_image: instance segmentation mask (one gray level per
            instance, 0 = background), or a path to one.
        sketch_image: sketch conditioning input forwarded to
            prepare_sketch_images.
        seed: RNG seed; pass -1 to leave the generator unseeded.
        num_inference_steps: diffusion sampling steps.
        guidance_scale: classifier-free guidance strength.
        do_image_padding: if True, center/pad the inputs first via
            preprocess_image.

    Returns:
        A trimesh.Scene with one mesh per segmented instance.
    """
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    # One RGB crop + binary mask per instance, plus the full scene
    # repeated so the three lists are index-aligned.
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
    sketch_image_list = prepare_sketch_images(
        sketch_image,
        5,
        1,
        seg_images=[instance_masks],
        mode="zoom",
        cfg=False,
    )
    # NOTE(review): debug artifact — dumps the zoomed sketches into the
    # CWD on every call; consider removing or gating this for production.
    for idx, i in enumerate(sketch_image_list[0]):
        i.save(f"test3_zoom_{str(idx)}.png", format="PNG")
    pipe_kwargs = {}
    if seed != -1 and isinstance(seed, int):
        pipe_kwargs["generator"] = torch.Generator(device=pipe.device).manual_seed(seed)

    num_instances = len(instance_rgbs)
    # NOTE(review): the first positional argument ("TEST1") looks like a
    # leftover debug tag — confirm its meaning against MIDIPipeline.__call__.
    outputs = pipe(
        "TEST1",
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        sketch_image=sketch_image_list,
        attention_kwargs={"num_instances": num_instances},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
        **pipe_kwargs,
    )

    # Each per-instance output is (logits, grid_size, bbox_size, bbox_min,
    # bbox_max): smooth the occupancy grid, extract the 0-level surface
    # with marching cubes, then map vertices from grid coordinates back
    # into the instance's world-space bounding box.
    trimeshes = []
    for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate(
        zip(*outputs)
    ):
        grid_logits = logits_.view(grid_size)
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
        vertices = vertices / grid_size * bbox_size + bbox_min

        # trimesh expects contiguous arrays; float32 keeps the export small.
        mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        trimeshes.append(mesh)

    # Bundle all instance meshes into a single scene for export.
    scene = trimesh.Scene(trimeshes)

    return scene
|
|
|
|
if __name__ == "__main__":
    device = "cuda"
    dtype = torch.bfloat16
    checkpoint_path = ""  # reserved; not consumed below

    parser = argparse.ArgumentParser()
    parser.add_argument("--rgb", type=str, required=True)
    parser.add_argument("--seg", type=str, required=True)
    parser.add_argument("--sketch", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    parser.add_argument("--do-image-padding", action="store_true")
    parser.add_argument("--output-dir", type=str, default="./")
    args = parser.parse_args()

    # Diagnostic: report whether an RGB sketch is effectively grayscale
    # (all three channels identical).
    sketch = Image.open(args.sketch)
    pixels = np.array(sketch)
    if pixels.ndim == 3 and pixels.shape[2] == 3:
        same = np.array_equal(pixels[..., 0], pixels[..., 1]) and np.array_equal(
            pixels[..., 1], pixels[..., 2]
        )
        print(f"三个通道是否完全相同: {same}")

    pipe = prepare_pipeline(device, dtype)
    scene = run_midi(
        pipe,
        rgb_image=args.rgb,
        seg_image=args.seg,
        sketch_image=[sketch],
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        do_image_padding=args.do_image_padding,
    )
    scene.export(os.path.join(args.output_dir, "output.glb"))
|
|