madtune committed on
Commit
e9200bf
·
verified ·
1 Parent(s): fe7e8a6

Upload folder using huggingface_hub

Browse files
pixeldit/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .pipeline import PixelDiTPipeline
2
+ from .modeling_pixeldit import load_pixeldit
3
+ from .modeling_pixeldit_hf import PixelDiTModel
4
+ from .configuration_pixeldit import PixelDiTConfig
5
+ from .text_encoder_gemma import GemmaEncoder
6
+ from .text_encoder_qwen import QwenEncoder
7
+ from .scheduling_flow import FlowScheduler
8
+
9
+ __all__ = [
10
+ "PixelDiTPipeline",
11
+ "load_pixeldit",
12
+ "PixelDiTModel",
13
+ "PixelDiTConfig",
14
+ "GemmaEncoder",
15
+ "QwenEncoder",
16
+ "FlowScheduler",
17
+ ]
pixeldit/configuration_pixeldit.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class PixelDiTConfig(PretrainedConfig):
    """Hugging Face configuration holding the PixelDiT architecture settings.

    Every constructor argument is stored unchanged as an attribute of the same
    name. ``PixelDiTModel`` reads these attributes and forwards them as keyword
    arguments to ``PixDiT_T2I``. The defaults for the fields that also appear in
    ``modeling_pixeldit._ARCH`` are identical to the values in that dict.

    NOTE(review): the meaning of each field is defined by ``PixDiT_T2I``, which
    is not part of this package — confirm semantics against that implementation
    before changing a default.

    Args:
        in_channels: forwarded to ``PixDiT_T2I(in_channels=...)``.
        num_groups: forwarded to ``PixDiT_T2I(num_groups=...)``.
        hidden_size: forwarded to ``PixDiT_T2I(hidden_size=...)``.
        pixel_hidden_size: forwarded to ``PixDiT_T2I(pixel_hidden_size=...)``.
        pixel_attn_hidden_size: forwarded to
            ``PixDiT_T2I(pixel_attn_hidden_size=...)``.
        pixel_num_groups: forwarded to ``PixDiT_T2I(pixel_num_groups=...)``.
        patch_depth: forwarded to ``PixDiT_T2I(patch_depth=...)``.
        pixel_depth: forwarded to ``PixDiT_T2I(pixel_depth=...)``.
        num_text_blocks: forwarded to ``PixDiT_T2I(num_text_blocks=...)``.
        patch_size: forwarded to ``PixDiT_T2I(patch_size=...)``.
        txt_embed_dim: forwarded to ``PixDiT_T2I(txt_embed_dim=...)``; the
            default (2304) equals the embedding width the text encoders in this
            package emit.
        txt_max_length: forwarded to ``PixDiT_T2I(txt_max_length=...)``; the
            default (300) equals the sequence length the text encoders in this
            package emit.
        use_text_rope: forwarded to ``PixDiT_T2I(use_text_rope=...)``.
        text_rope_theta: forwarded to ``PixDiT_T2I(text_rope_theta=...)``.
        repa_encoder_index: forwarded to
            ``PixDiT_T2I(repa_encoder_index=...)``.
        use_pixel_abs_pos: forwarded to ``PixDiT_T2I(use_pixel_abs_pos=...)``.
        **kwargs: passed through to ``PretrainedConfig.__init__``.
    """

    # Identifier stored with the serialized config.
    model_type = "pixeldit"

    def __init__(
        self,
        in_channels=3,
        num_groups=24,
        hidden_size=1536,
        pixel_hidden_size=16,
        pixel_attn_hidden_size=1152,
        pixel_num_groups=16,
        patch_depth=14,
        pixel_depth=2,
        num_text_blocks=4,
        patch_size=16,
        txt_embed_dim=2304,
        txt_max_length=300,
        use_text_rope=True,
        text_rope_theta=10000.0,
        repa_encoder_index=-1,
        use_pixel_abs_pos=True,
        **kwargs,
    ):
        # Base-class setup runs first; the architecture fields are attached after.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_groups = num_groups
        self.hidden_size = hidden_size
        self.pixel_hidden_size = pixel_hidden_size
        self.pixel_attn_hidden_size = pixel_attn_hidden_size
        self.pixel_num_groups = pixel_num_groups
        self.patch_depth = patch_depth
        self.pixel_depth = pixel_depth
        self.num_text_blocks = num_text_blocks
        self.patch_size = patch_size
        self.txt_embed_dim = txt_embed_dim
        self.txt_max_length = txt_max_length
        self.use_text_rope = use_text_rope
        self.text_rope_theta = text_rope_theta
        self.repa_encoder_index = repa_encoder_index
        self.use_pixel_abs_pos = use_pixel_abs_pos
pixeldit/modeling_pixeldit.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PixelDiT model loader.
3
+
4
+ Usage:
5
+ from pixeldit.modeling_pixeldit import load_pixeldit
6
+ model = load_pixeldit()
7
+ out = model(x, t, y) # [B,3,H,W], [B], [B,300,2304] -> [B,3,H,W]
8
+ """
9
+
10
+ import sys
11
+ import torch
12
+
13
+ sys.path.insert(0, "/home/nobus/Raid0/PixelDiT")
14
+ from pixdit_core.pixeldit_t2i import PixDiT_T2I
15
+
16
+ _CKPT = (
17
+ "/home/nobus/.cache/huggingface/hub/"
18
+ "models--nvidia--PixelDiT-1300M-1024px/snapshots/"
19
+ "7c63b99a7a399918a1d6478b095698a65f664847/pixeldit_t2i_v1.pth"
20
+ )
21
+
22
+ _ARCH = dict(
23
+ in_channels=3,
24
+ num_groups=24,
25
+ hidden_size=1536,
26
+ pixel_hidden_size=16,
27
+ pixel_attn_hidden_size=1152,
28
+ pixel_num_groups=16,
29
+ patch_depth=14,
30
+ pixel_depth=2,
31
+ patch_size=16,
32
+ txt_embed_dim=2304,
33
+ txt_max_length=300,
34
+ )
35
+
36
+
37
def load_pixeldit(checkpoint=_CKPT, device="cuda", dtype=torch.bfloat16):
    """Build a ``PixDiT_T2I`` from ``_ARCH`` and load weights from a checkpoint.

    Args:
        checkpoint: path to the ``.pth`` file. The file may hold the state dict
            directly or wrap it under a ``"state_dict"`` key.
        device: device the model is moved to after loading.
        dtype: dtype the model is cast to after loading.

    Returns:
        The model in eval mode, on ``device`` and cast to ``dtype``.
    """
    model = PixDiT_T2I(**_ARCH)

    # SECURITY: weights_only=False runs the full pickle machinery and can
    # execute arbitrary code embedded in the file. Only point this at
    # checkpoints from a trusted source.
    state = torch.load(checkpoint, map_location="cpu", weights_only=False)
    sd = state.get("state_dict", state)

    # Keys saved with a leading "core." are renamed so they line up with the
    # bare module's parameter names; other keys pass through unchanged.
    sd = {k.removeprefix("core."): v for k, v in sd.items()}

    # strict=False tolerates mismatches, so report both directions instead of
    # silently dropping checkpoint entries that matched nothing in the model.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"[modeling] {len(missing)} missing keys (expected)")
    if unexpected:
        print(f"[modeling] {len(unexpected)} unexpected keys ignored")

    model = model.to(device).to(dtype).eval()
    print(f"[modeling] PixelDiT loaded — {sum(p.numel() for p in model.parameters()):,} params")
    return model
pixeldit/modeling_pixeldit_hf.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF-compatible PixelDiT wrapper.
3
+
4
+ Allows save_pretrained / from_pretrained and peft LoRA targeting.
5
+
6
+ Usage:
7
+ # Convert from original .pth
8
+ model = PixelDiTModel.from_pth("pixeldit_t2i_v1.pth")
9
+ model.save_pretrained("pixeldit-diffusers/")
10
+
11
+ # Load back
12
+ model = PixelDiTModel.from_pretrained("pixeldit-diffusers/")
13
+
14
+ # LoRA
15
+ from peft import get_peft_model, LoraConfig
16
+ lora_cfg = LoraConfig(target_modules=["qkv_x", "qkv_y", "proj_x", "proj_y"])
17
+ model = get_peft_model(model, lora_cfg)
18
+ """
19
+
20
+ import sys
21
+ import torch
22
+ from transformers import PreTrainedModel
23
+
24
+ sys.path.insert(0, "/home/nobus/Raid0/PixelDiT")
25
+ from pixdit_core.pixeldit_t2i import PixDiT_T2I
26
+
27
+ from .configuration_pixeldit import PixelDiTConfig
28
+
29
+
30
class PixelDiTModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that holds a ``PixDiT_T2I`` as ``self.model``.

    Because the network lives under the ``model`` attribute, its parameters
    appear in this wrapper's state dict with a ``model.`` prefix.
    """

    config_class = PixelDiTConfig
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        # Always reports an empty mapping of tied weights.
        return {}

    def __init__(self, config: PixelDiTConfig):
        super().__init__(config)
        # Config attributes handed to PixDiT_T2I under the same keyword names.
        arch_fields = (
            "in_channels",
            "num_groups",
            "hidden_size",
            "pixel_hidden_size",
            "pixel_attn_hidden_size",
            "pixel_num_groups",
            "patch_depth",
            "pixel_depth",
            "num_text_blocks",
            "patch_size",
            "txt_embed_dim",
            "txt_max_length",
            "use_text_rope",
            "text_rope_theta",
            "repa_encoder_index",
            "use_pixel_abs_pos",
        )
        arch_kwargs = {field: getattr(config, field) for field in arch_fields}
        self.model = PixDiT_T2I(**arch_kwargs)

    def forward(self, x, t, y, s=None, mask=None):
        """Delegate to the wrapped ``PixDiT_T2I`` with the same arguments."""
        return self.model(x, t, y, s=s, mask=mask)

    @classmethod
    def from_pth(cls, pth_path: str, config: PixelDiTConfig = None):
        """Create a model and fill it from an original ``.pth`` checkpoint.

        A leading ``core.`` is dropped from each checkpoint key and ``model.``
        is prepended so the key matches this wrapper's parameter names. A
        default ``PixelDiTConfig`` is used when ``config`` is not given.
        Loading is non-strict; the counts of missing and unexpected keys are
        printed.
        """
        instance = cls(PixelDiTConfig() if config is None else config)

        checkpoint = torch.load(pth_path, map_location="cpu", weights_only=False)
        weights = checkpoint.get("state_dict", checkpoint)

        # Single pass: strip the "core." prefix, then add the "model." prefix.
        remapped = {
            "model." + name.removeprefix("core."): tensor
            for name, tensor in weights.items()
        }

        missing, unexpected = instance.load_state_dict(remapped, strict=False)
        print(f"[PixelDiTModel.from_pth] loaded — {len(missing)} missing, {len(unexpected)} unexpected")
        return instance
pixeldit/pipeline.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PixelDiT T2I Pipeline — thin orchestrator.
3
+
4
+ Usage:
5
+ from pixeldit.pipeline import PixelDiTPipeline
6
+ pipe = PixelDiTPipeline()
7
+ images = pipe("a viking at sunset", height=512, width=512)
8
+ images[0].save("out.jpg")
9
+ """
10
+
11
+ import torch
12
+ from PIL import Image
13
+ from .modeling_pixeldit import load_pixeldit
14
+ from .modeling_pixeldit_hf import PixelDiTModel
15
+ from .text_encoder_gemma import GemmaEncoder
16
+ from .text_encoder_qwen import QwenEncoder
17
+ from .scheduling_flow import FlowScheduler
18
+
19
+
20
class PixelDiTPipeline:
    """Text-to-image pipeline: text encoder -> PixelDiT model -> flow sampler.

    Args:
        text_encoder: ``"qwen"`` selects ``QwenEncoder``; any other value
            selects ``GemmaEncoder``.
        qwen_proj: projection checkpoint path passed to ``QwenEncoder``
            (ignored for Gemma).
        device: device for the diffusion model, the noise and the encoder
            outputs.
        dtype: dtype for the diffusion model, the noise and the encoder
            outputs.
        cfg: default classifier-free guidance scale given to the scheduler.
        flow_shift: flow shift given to the scheduler.
        pretrained: HF directory or repo id. When set, the model is loaded with
            ``PixelDiTModel.from_pretrained``; otherwise ``load_pixeldit`` is
            used.
    """

    def __init__(
        self,
        text_encoder="gemma",   # "gemma" | "qwen"
        qwen_proj=None,
        device="cuda",
        dtype=torch.bfloat16,
        cfg=3.5,
        flow_shift=4.0,
        pretrained=None,        # HF dir or repo id — loads via from_pretrained instead of .pth
    ):
        self.device = torch.device(device)
        self.dtype = dtype

        if text_encoder == "qwen":
            self.encoder = QwenEncoder(proj_path=qwen_proj, output_device=device, output_dtype=dtype)
        else:
            self.encoder = GemmaEncoder(output_device=device, output_dtype=dtype)

        if pretrained is not None:
            print(f"[pipeline] loading from HF: {pretrained}")
            self.model = (
                PixelDiTModel.from_pretrained(pretrained)
                .to(device).to(dtype).eval()
            )
        else:
            self.model = load_pixeldit(device=device, dtype=dtype)

        self.scheduler = FlowScheduler(self.model, cfg=cfg, flow_shift=flow_shift)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        height=512,
        width=512,
        steps=20,
        cfg=None,
        seed=None,
    ):
        """Generate one image per prompt.

        Args:
            prompt: a single string or an iterable of strings.
            negative_prompt: text for the unconditional branch; when empty the
                encoder's null embedding is used instead.
            height: height of the sampled noise tensor, in pixels.
            width: width of the sampled noise tensor, in pixels.
            steps: number of sampler steps.
            cfg: guidance scale for this call only. ``None`` keeps the
                scheduler's current value.
            seed: when given, seeds torch's global RNG before the noise is
                drawn.

        Returns:
            A list of ``PIL.Image`` objects, one per prompt.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        batch = len(prompts)

        if seed is not None:
            torch.manual_seed(seed)

        cond = self.encoder.encode(prompts)
        if negative_prompt:
            uncond = self.encoder.encode([negative_prompt] * batch)
        else:
            uncond = self.encoder.encode_null(batch)

        noise = torch.randn(batch, 3, height, width, device=self.device, dtype=self.dtype)

        # A per-call cfg must not leak into later calls: apply it only around
        # sampling and always put the previous value back, even on failure.
        previous_cfg = self.scheduler.cfg
        if cfg is not None:
            self.scheduler.cfg = cfg
        try:
            imgs = self.scheduler.sample(noise, cond, uncond, steps=steps)
        finally:
            self.scheduler.cfg = previous_cfg

        # Convert in float32 and round to the nearest level; doing this in a
        # reduced-precision dtype with truncation loses precision and darkens
        # the result slightly.
        imgs = (imgs.float().clamp(-1, 1) + 1) / 2
        imgs = (imgs * 255).round().to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(img) for img in imgs]
pixeldit/scheduling_flow.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flow-matching DPM-Solver++ sampler for PixelDiT.
3
+
4
+ Wraps the original DPMS from the PixelDiT repo.
5
+ Order=2 multistep gets quality at 20 steps that Euler needs 100+ for.
6
+
7
+ Usage:
8
+ from pixeldit.scheduling_flow import FlowScheduler
9
+
10
+ scheduler = FlowScheduler(model_fn, cfg=3.5, flow_shift=4.0)
11
+ image = scheduler.sample(noise, cond, uncond, steps=20)
12
+ """
13
+
14
+ import sys
15
+ import torch
16
+ from tqdm import tqdm
17
+
18
+ sys.path.insert(0, "/home/nobus/Raid0/PixelDiT/t2i")
19
+ from diffusion.model.flow_dpm import DPMS
20
+
21
_FLOW_SHIFT = 4.0  # 1024px stage-3 config


class FlowScheduler:
    """Classifier-free-guided sampler built on the PixelDiT repo's ``DPMS``."""

    def __init__(self, model_fn, cfg=3.5, flow_shift=_FLOW_SHIFT):
        """
        model_fn:   callable(x, t, y) -> velocity [B,3,H,W]
        cfg:        classifier-free guidance scale
        flow_shift: value passed to ``DPMS.sample`` as ``flow_shift``
        """

        def _call_model(x, t, y):
            # DPMS passes y as [B,1,L,D] but PixDiT_T2I expects [B,L,D], so a
            # 4-D conditioning tensor has dim 1 squeezed before the call.
            if y.dim() == 4:
                y = y.squeeze(1)
            return model_fn(x, t, y)

        self.model_fn = _call_model
        self.cfg = cfg
        self.flow_shift = flow_shift

    @torch.no_grad()
    def sample(
        self,
        noise: torch.Tensor,    # [B, 3, H, W]  Gaussian noise
        cond: torch.Tensor,     # [B, 300, 2304]
        uncond: torch.Tensor,   # [B, 300, 2304]
        steps: int = 20,
    ) -> torch.Tensor:
        """Returns denoised image tensor [B, 3, H, W] in [-1, 1]."""
        # DPMS expects [B, 1, L, D], hence the extra axis on both embeddings.
        solver = DPMS(
            self.model_fn,
            condition=cond.unsqueeze(1),
            uncondition=uncond.unsqueeze(1),
            cfg_scale=self.cfg,
            model_type="flow",
            schedule="FLOW",
            guidance_type="classifier-free",
            interval_guidance=[0, 1],
        )
        # Order-2 multistep update with the configured flow shift.
        return solver.sample(
            noise,
            steps=steps,
            order=2,
            skip_type="time_uniform_flow",
            method="multistep",
            flow_shift=self.flow_shift,
        )
pixeldit/text_encoder_gemma.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gemma-2-2B text encoder for PixelDiT.
3
+ Handles chi_prompt prefix + select_index to match training exactly.
4
+
5
+ Usage:
6
+ from pixeldit.text_encoder_gemma import GemmaEncoder
7
+ enc = GemmaEncoder()
8
+ cond = enc.encode(["a dragon at sunset"]) # [1, 300, 2304]
9
+ null = enc.encode_null(1) # [1, 300, 2304]
10
+ """
11
+
12
+ import torch
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+
15
+ _GEMMA_ID = "Efficient-Large-Model/gemma-2-2b-it"
16
+ _GEMMA_DIM = 2304
17
+ _TXT_MAX = 300
18
+
19
+ _CHI_PROMPT = "\n".join([
20
+ 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:',
21
+ '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.',
22
+ '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.',
23
+ 'Here are examples of how to transform or refine prompts:',
24
+ '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.',
25
+ '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.',
26
+ 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:',
27
+ 'User Prompt: ',
28
+ ])
29
+ _SELECT_IDX = [0] + list(range(-(_TXT_MAX - 1), 0))
30
+
31
+
32
class GemmaEncoder:
    """Text encoder that runs the decoder stack of a Gemma causal LM.

    The language model is loaded in float32 and is never moved by this class;
    only the returned embeddings are transferred to ``output_device`` and cast
    to ``output_dtype``.
    """

    def __init__(
        self,
        model_id=_GEMMA_ID,
        output_device="cuda",
        output_dtype=torch.bfloat16,
    ):
        self.output_device = torch.device(output_device)
        self.output_dtype = output_dtype

        print(f"[GemmaEncoder] loading {model_id} (CPU)")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.padding_side = "right"
        self.tokenizer = tokenizer

        causal_lm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
        self._model = causal_lm.get_decoder().eval()

        # Token count of the instruction prefix as the tokenizer encodes it;
        # encode() uses it to size the padded sequence.
        self._num_chi_tokens = len(self.tokenizer.encode(_CHI_PROMPT))
        print("[GemmaEncoder] ready")

    def _hidden_states(self, texts, max_length):
        """Tokenize ``texts`` padded/truncated to ``max_length`` and return the
        decoder's last hidden state."""
        batch = self.tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        output = self._model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
        )
        return output.last_hidden_state

    def _to_output(self, emb):
        """Move ``emb`` to the configured output device, then cast its dtype."""
        return emb.to(self.output_device).to(self.output_dtype)

    @torch.no_grad()
    def encode(self, texts: list[str]) -> torch.Tensor:
        """Returns [B, 300, 2304]."""
        prefixed = [_CHI_PROMPT + text for text in texts]
        # Sequence length = instruction-prefix tokens plus the text budget.
        hidden = self._hidden_states(prefixed, self._num_chi_tokens + _TXT_MAX - 2)
        # Keep position 0 and the trailing _TXT_MAX - 1 positions.
        return self._to_output(hidden[:, _SELECT_IDX, :])

    @torch.no_grad()
    def encode_null(self, batch_size: int) -> torch.Tensor:
        """Returns [B, 300, 2304] for empty string (CFG unconditional)."""
        # No instruction prefix and no position selection on this path.
        hidden = self._hidden_states([""] * batch_size, _TXT_MAX)
        return self._to_output(hidden)
pixeldit/text_encoder_qwen.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Qwen3-2B text encoder for PixelDiT.
3
+ Requires a trained projection (train_qwen_proj.py) to map 2048→2304.
4
+
5
+ Usage:
6
+ from pixeldit.text_encoder_qwen import QwenEncoder
7
+ enc = QwenEncoder(proj_path="pixeldit/qwen_proj.pt")
8
+ cond = enc.encode(["a dragon at sunset"]) # [1, 300, 2304]
9
+ null = enc.encode_null(1) # [1, 300, 2304]
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from transformers import AutoTokenizer, AutoModel
15
+
16
_QWEN_ID = "Qwen/Qwen3-2B"
_QWEN_DIM = 2048
_GEMMA_DIM = 2304
_TXT_MAX = 300


class QwenEncoder:
    """Text encoder: Qwen hidden states followed by a bias-free linear
    projection from ``_QWEN_DIM`` to ``_GEMMA_DIM`` features.

    The language model is loaded in float32 and is never moved by this class;
    the projection layer lives on ``output_device`` in ``output_dtype``.
    """

    def __init__(
        self,
        model_id=_QWEN_ID,
        proj_path=None,  # path to trained qwen_proj.pt
        output_device="cuda",
        output_dtype=torch.bfloat16,
    ):
        self.output_device = torch.device(output_device)
        self.output_dtype = output_dtype

        print(f"[QwenEncoder] loading {model_id} (CPU)")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.padding_side = "right"
        self.tokenizer = tokenizer
        self._model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float32).eval()

        projection = nn.Linear(_QWEN_DIM, _GEMMA_DIM, bias=False)
        if proj_path:
            projection.load_state_dict(
                torch.load(proj_path, map_location="cpu", weights_only=True)
            )
            print(f"[QwenEncoder] loaded projection: {proj_path}")
        else:
            # Fallback: identity on the first _QWEN_DIM output features and
            # zeros on the remaining ones.
            with torch.no_grad():
                projection.weight.zero_()
                projection.weight[:_QWEN_DIM].copy_(torch.eye(_QWEN_DIM))
            print("[QwenEncoder] projection: identity init — run train_qwen_proj.py for real quality")
        self.proj = projection.to(self.output_device).to(output_dtype)
        print("[QwenEncoder] ready")

    def _project(self, texts):
        """Tokenize ``texts`` to ``_TXT_MAX`` tokens, run the language model,
        and project its last hidden state on the output device/dtype."""
        tok = self.tokenizer(
            texts,
            max_length=_TXT_MAX,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        hidden = self._model(**tok).last_hidden_state
        hidden = hidden.to(self.output_device).to(self.output_dtype)
        return self.proj(hidden)

    @torch.no_grad()
    def encode(self, texts: list[str]) -> torch.Tensor:
        """Returns [B, 300, 2304]."""
        return self._project(texts)

    @torch.no_grad()
    def encode_null(self, batch_size: int) -> torch.Tensor:
        """Returns [B, 300, 2304] for empty string (CFG unconditional)."""
        return self._project([""] * batch_size)