madtune committed on
Commit
33047ba
·
verified ·
1 Parent(s): a4adb91

Delete pixeldit/text_encoder_qwen.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/text_encoder_qwen.py +0 -72
pixeldit/text_encoder_qwen.py DELETED
@@ -1,72 +0,0 @@
1
- """
2
- Qwen3-2B text encoder for PixelDiT.
3
- Requires a trained projection (train_qwen_proj.py) to map 2048→2304.
4
-
5
- Usage:
6
- from pixeldit.text_encoder_qwen import QwenEncoder
7
- enc = QwenEncoder(proj_path="pixeldit/qwen_proj.pt")
8
- cond = enc.encode(["a dragon at sunset"]) # [1, 300, 2304]
9
- null = enc.encode_null(1) # [1, 300, 2304]
10
- """
11
-
12
- import torch
13
- import torch.nn as nn
14
- from transformers import AutoTokenizer, AutoModel
15
-
16
- _QWEN_ID = "Qwen/Qwen3-2B"
17
- _QWEN_DIM = 2048
18
- _GEMMA_DIM = 2304
19
- _TXT_MAX = 300
20
-
21
-
22
- class QwenEncoder:
23
- def __init__(
24
- self,
25
- model_id=_QWEN_ID,
26
- proj_path=None, # path to trained qwen_proj.pt
27
- output_device="cuda",
28
- output_dtype=torch.bfloat16,
29
- ):
30
- self.output_device = torch.device(output_device)
31
- self.output_dtype = output_dtype
32
-
33
- print(f"[QwenEncoder] loading {model_id} (CPU)")
34
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
35
- self.tokenizer.padding_side = "right"
36
- self._model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float32).eval()
37
-
38
- self.proj = nn.Linear(_QWEN_DIM, _GEMMA_DIM, bias=False)
39
- if proj_path:
40
- sd = torch.load(proj_path, map_location="cpu", weights_only=True)
41
- self.proj.load_state_dict(sd)
42
- print(f"[QwenEncoder] loaded projection: {proj_path}")
43
- else:
44
- with torch.no_grad():
45
- w = torch.zeros(_GEMMA_DIM, _QWEN_DIM)
46
- w[:_QWEN_DIM] = torch.eye(_QWEN_DIM)
47
- self.proj.weight.copy_(w)
48
- print("[QwenEncoder] projection: identity init — run train_qwen_proj.py for real quality")
49
- self.proj = self.proj.to(self.output_device).to(output_dtype)
50
- print("[QwenEncoder] ready")
51
-
52
- @torch.no_grad()
53
- def encode(self, texts: list[str]) -> torch.Tensor:
54
- """Returns [B, 300, 2304]."""
55
- tok = self.tokenizer(
56
- texts, max_length=_TXT_MAX,
57
- padding="max_length", truncation=True, return_tensors="pt",
58
- )
59
- emb = self._model(**tok).last_hidden_state
60
- emb = emb.to(self.output_device).to(self.output_dtype)
61
- return self.proj(emb)
62
-
63
- @torch.no_grad()
64
- def encode_null(self, batch_size: int) -> torch.Tensor:
65
- """Returns [B, 300, 2304] for empty string (CFG unconditional)."""
66
- tok = self.tokenizer(
67
- [""] * batch_size, max_length=_TXT_MAX,
68
- padding="max_length", truncation=True, return_tensors="pt",
69
- )
70
- emb = self._model(**tok).last_hidden_state
71
- emb = emb.to(self.output_device).to(self.output_dtype)
72
- return self.proj(emb)