madtune committed on
Commit
d9d6c63
·
verified ·
1 Parent(s): 3a32a52

Delete pixeldit/modeling_pixeldit_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/modeling_pixeldit_hf.py +0 -75
pixeldit/modeling_pixeldit_hf.py DELETED
@@ -1,75 +0,0 @@
1
- """
2
- HF-compatible PixelDiT wrapper.
3
-
4
- Allows save_pretrained / from_pretrained and peft LoRA targeting.
5
-
6
- Usage:
7
- # Convert from original .pth
8
- model = PixelDiTModel.from_pth("pixeldit_t2i_v1.pth")
9
- model.save_pretrained("pixeldit-diffusers/")
10
-
11
- # Load back
12
- model = PixelDiTModel.from_pretrained("pixeldit-diffusers/")
13
-
14
- # LoRA
15
- from peft import get_peft_model, LoraConfig
16
- lora_cfg = LoraConfig(target_modules=["qkv_x", "qkv_y", "proj_x", "proj_y"])
17
- model = get_peft_model(model, lora_cfg)
18
- """
19
-
20
- import sys
21
- import torch
22
- from transformers import PreTrainedModel
23
-
24
- sys.path.insert(0, "/home/nobus/Raid0/PixelDiT")
25
- from pixdit_core.pixeldit_t2i import PixDiT_T2I
26
-
27
- from .configuration_pixeldit import PixelDiTConfig
28
-
29
-
30
class PixelDiTModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``PixDiT_T2I``.

    The wrapped network is stored on ``self.model``, so every parameter name
    in this module's state dict carries a ``model.`` prefix.
    """

    config_class = PixelDiTConfig
    # The wrapper declares no tied (shared) parameters.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        """Return an empty mapping: no tied weights are declared."""
        # NOTE(review): presumably shadows the base class's tied-weight
        # bookkeeping -- confirm against the installed transformers version.
        return {}

    def __init__(self, config: PixelDiTConfig):
        """Build the wrapped ``PixDiT_T2I`` from the fields of *config*."""
        super().__init__(config)
        self.model = PixDiT_T2I(
            in_channels=config.in_channels,
            num_groups=config.num_groups,
            hidden_size=config.hidden_size,
            pixel_hidden_size=config.pixel_hidden_size,
            pixel_attn_hidden_size=config.pixel_attn_hidden_size,
            pixel_num_groups=config.pixel_num_groups,
            patch_depth=config.patch_depth,
            pixel_depth=config.pixel_depth,
            num_text_blocks=config.num_text_blocks,
            patch_size=config.patch_size,
            txt_embed_dim=config.txt_embed_dim,
            txt_max_length=config.txt_max_length,
            use_text_rope=config.use_text_rope,
            text_rope_theta=config.text_rope_theta,
            repa_encoder_index=config.repa_encoder_index,
            use_pixel_abs_pos=config.use_pixel_abs_pos,
        )

    def forward(self, x, t, y, s=None, mask=None):
        """Delegate to the wrapped network; ``s`` and ``mask`` go by keyword."""
        return self.model(x, t, y, s=s, mask=mask)

    @classmethod
    def from_pth(cls, pth_path: str, config: PixelDiTConfig = None):
        """Load from original nvidia .pth checkpoint, handles core. prefix.

        Args:
            pth_path: Path to the ``.pth`` checkpoint file.
            config: Model configuration. When ``None``, a default
                ``PixelDiTConfig()`` is used.

        Returns:
            A new ``PixelDiTModel`` with the checkpoint weights loaded
            non-strictly (``strict=False``).

        Warns:
            UserWarning: If any keys are missing from, or unexpected in, the
                checkpoint. The affected parameters keep their freshly
                initialised values, so the mismatch is named rather than
                only counted.
        """
        import warnings

        if config is None:
            config = PixelDiTConfig()
        model = cls(config)

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only call this on
        # checkpoints from a trusted source.
        state = torch.load(pth_path, map_location="cpu", weights_only=False)
        # The checkpoint is either a bare state dict or one nested under
        # the "state_dict" key.
        sd = state.get("state_dict", state)

        # Strip the trainer wrapper prefix (when present) and add the
        # "model." prefix under which this wrapper stores the network.
        trainer_prefix = "core."
        sd = {
            "model." + (k[len(trainer_prefix):] if k.startswith(trainer_prefix) else k): v
            for k, v in sd.items()
        }

        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"[PixelDiTModel.from_pth] loaded — {len(missing)} missing, {len(unexpected)} unexpected")

        if missing or unexpected:
            # A non-strict load never raises, so report which keys differ.
            preview = 10
            warnings.warn(
                "[PixelDiTModel.from_pth] checkpoint does not match the model exactly; "
                f"missing keys (first {preview}): {list(missing[:preview])}; "
                f"unexpected keys (first {preview}): {list(unexpected[:preview])}",
                stacklevel=2,
            )
        return model