madtune commited on
Commit
3a32a52
·
verified ·
1 Parent(s): c522d57

Delete pixeldit/modeling_pixeldit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/modeling_pixeldit.py +0 -47
pixeldit/modeling_pixeldit.py DELETED
@@ -1,47 +0,0 @@
1
- """
2
- PixelDiT model loader.
3
-
4
- Usage:
5
- from modeling_pixeldit import load_pixeldit
6
- model = load_pixeldit()
7
- out = model(x, t, y) # [B,3,H,W], [B], [B,300,2304] -> [B,3,H,W]
8
- """
9
-
10
- import sys
11
- import torch
12
-
13
- sys.path.insert(0, "/home/nobus/Raid0/PixelDiT")
14
- from pixdit_core.pixeldit_t2i import PixDiT_T2I
15
-
16
- _CKPT = (
17
- "/home/nobus/.cache/huggingface/hub/"
18
- "models--nvidia--PixelDiT-1300M-1024px/snapshots/"
19
- "7c63b99a7a399918a1d6478b095698a65f664847/pixeldit_t2i_v1.pth"
20
- )
21
-
22
- _ARCH = dict(
23
- in_channels=3,
24
- num_groups=24,
25
- hidden_size=1536,
26
- pixel_hidden_size=16,
27
- pixel_attn_hidden_size=1152,
28
- pixel_num_groups=16,
29
- patch_depth=14,
30
- pixel_depth=2,
31
- patch_size=16,
32
- txt_embed_dim=2304,
33
- txt_max_length=300,
34
- )
35
-
36
-
37
- def load_pixeldit(checkpoint=_CKPT, device="cuda", dtype=torch.bfloat16):
38
- model = PixDiT_T2I(**_ARCH)
39
- state = torch.load(checkpoint, map_location="cpu", weights_only=False)
40
- sd = state.get("state_dict", state)
41
- sd = {(k[5:] if k.startswith("core.") else k): v for k, v in sd.items()}
42
- missing, _ = model.load_state_dict(sd, strict=False)
43
- if missing:
44
- print(f"[modeling] {len(missing)} missing keys (expected)")
45
- model = model.to(device).to(dtype).eval()
46
- print(f"[modeling] PixelDiT loaded — {sum(p.numel() for p in model.parameters()):,} params")
47
- return model