madtune commited on
Commit
a4adb91
·
verified ·
1 Parent(s): 5204049

Delete pixeldit/text_encoder_gemma.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/text_encoder_gemma.py +0 -79
pixeldit/text_encoder_gemma.py DELETED
@@ -1,79 +0,0 @@
1
- """
2
- Gemma-2-2B text encoder for PixelDiT.
3
- Handles chi_prompt prefix + select_index to match training exactly.
4
-
5
- Usage:
6
- from pixeldit.text_encoder_gemma import GemmaEncoder
7
- enc = GemmaEncoder()
8
- cond = enc.encode(["a dragon at sunset"]) # [1, 300, 2304]
9
- null = enc.encode_null(1) # [1, 300, 2304]
10
- """
11
-
12
- import torch
13
- from transformers import AutoTokenizer, AutoModelForCausalLM
14
-
15
- _GEMMA_ID = "Efficient-Large-Model/gemma-2-2b-it"
16
- _GEMMA_DIM = 2304
17
- _TXT_MAX = 300
18
-
19
- _CHI_PROMPT = "\n".join([
20
- 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:',
21
- '- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.',
22
- '- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.',
23
- 'Here are examples of how to transform or refine prompts:',
24
- '- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.',
25
- '- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.',
26
- 'Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:',
27
- 'User Prompt: ',
28
- ])
29
- _SELECT_IDX = [0] + list(range(-(_TXT_MAX - 1), 0))
30
-
31
-
32
- class GemmaEncoder:
33
- def __init__(
34
- self,
35
- model_id=_GEMMA_ID,
36
- output_device="cuda",
37
- output_dtype=torch.bfloat16,
38
- ):
39
- self.output_device = torch.device(output_device)
40
- self.output_dtype = output_dtype
41
-
42
- print(f"[GemmaEncoder] loading {model_id} (CPU)")
43
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
44
- self.tokenizer.padding_side = "right"
45
- self._model = (
46
- AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
47
- .get_decoder().eval()
48
- )
49
- self._num_chi_tokens = len(self.tokenizer.encode(_CHI_PROMPT))
50
- print("[GemmaEncoder] ready")
51
-
52
- @torch.no_grad()
53
- def encode(self, texts: list[str]) -> torch.Tensor:
54
- """Returns [B, 300, 2304]."""
55
- texts_full = [_CHI_PROMPT + t for t in texts]
56
- max_len = self._num_chi_tokens + _TXT_MAX - 2
57
- tok = self.tokenizer(
58
- texts_full, max_length=max_len,
59
- padding="max_length", truncation=True, return_tensors="pt",
60
- )
61
- emb = self._model(
62
- input_ids=tok.input_ids,
63
- attention_mask=tok.attention_mask,
64
- ).last_hidden_state
65
- emb = emb[:, _SELECT_IDX, :]
66
- return emb.to(self.output_device).to(self.output_dtype)
67
-
68
- @torch.no_grad()
69
- def encode_null(self, batch_size: int) -> torch.Tensor:
70
- """Returns [B, 300, 2304] for empty string (CFG unconditional)."""
71
- tok = self.tokenizer(
72
- [""] * batch_size, max_length=_TXT_MAX,
73
- padding="max_length", truncation=True, return_tensors="pt",
74
- )
75
- emb = self._model(
76
- input_ids=tok.input_ids,
77
- attention_mask=tok.attention_mask,
78
- ).last_hidden_state
79
- return emb.to(self.output_device).to(self.output_dtype)