ppbrown commited on
Commit
8f28c81
·
verified ·
1 Parent(s): 58ef10f

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +144 -74
pipeline.py CHANGED
@@ -1,115 +1,185 @@
1
- # pipeline.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- # subclass SD pipeline to replace CLIP-L with T5
 
 
4
 
5
- import torch
6
- import torch.nn as nn
7
- from transformers import AutoTokenizer, AutoModel
8
- from transformers import T5Tokenizer, T5EncoderModel
9
- from diffusers import StableDiffusionPipeline
10
- from diffusers.utils import logging
11
 
12
- logger = logging.get_logger(__name__)
13
 
14
- T5_NAME="mcmonkey/google_t5-v1_1-xxl_encoderonly"
 
 
15
 
16
  class LinearWithDtype(nn.Linear):
17
  @property
18
  def dtype(self):
19
  return self.weight.dtype
20
 
21
- class StableDiffusionT5Pipeline(StableDiffusionPipeline):
22
 
23
- # override this so we can auto-init text_encoder
24
- #_optional_components = ["safety_checker", "feature_extractor", "image_encoder", "text_encoder"]
 
 
 
 
25
 
26
- # t5_projection not really optional, but needed it here to stop internal whining
27
- _optional_components = StableDiffusionPipeline._optional_components + ["text_encoder", "t5_projection"]
 
 
28
 
29
  def __init__(
30
  self,
31
- vae,
32
- text_encoder,
33
- tokenizer,
34
- unet,
35
- scheduler,
36
- safety_checker=None,
37
- feature_extractor=None,
38
- image_encoder=None,
39
- requires_safety_checker=True,
40
- t5_projection: LinearWithDtype=None,
 
41
  ):
42
- self.tokenizer = (
43
- tokenizer
44
- if tokenizer is not None
45
- else T5Tokenizer.from_pretrained(T5_NAME,torch_dtype=unet.dtype)
46
- )
47
 
 
 
 
 
 
48
 
49
- if text_encoder is None:
50
- self.text_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
 
51
  else:
52
- self.text_encoder = text_encoder
 
 
 
 
 
 
 
 
 
53
 
54
- super().__init__(
55
  vae=vae,
56
- tokenizer=self.tokenizer,
57
- text_encoder=self.text_encoder,
58
  unet=unet,
59
  scheduler=scheduler,
60
- safety_checker=safety_checker,
61
- feature_extractor=feature_extractor,
 
 
62
  image_encoder=image_encoder,
63
- requires_safety_checker=requires_safety_checker,
 
 
 
 
 
 
 
 
 
64
  )
65
 
66
- if t5_projection is None:
67
- self.t5_projection = LinearWithDtype(4096, 768).to(unet.device, dtype=unet.dtype)
68
- else:
69
- self.t5_projection = t5_projection
70
 
71
- # Ensure everything is properly registered for to("cuda")
72
- self.register_modules(t5_projection=self.t5_projection)
 
73
 
 
 
 
 
 
 
74
  def encode_prompt(
75
  self,
76
  prompt,
77
- device,
78
- num_images_per_prompt,
79
- do_classifier_free_guidance,
80
- negative_prompt = None,
81
- prompt_embeds = None,
82
- negative_prompt_embeds = None,
83
- lora_scale = None,
84
- clip_skip = None, # ignore with T5!!
85
- **kwargs,
86
  ):
87
-
88
- def _tok(text):
89
- out = self.tokenizer(
 
 
 
 
 
 
 
 
 
 
90
  text,
91
  return_tensors="pt",
92
  padding="max_length",
93
  max_length=self.tokenizer.model_max_length,
94
  truncation=True,
95
- )
96
- return out.input_ids.to(device=device, dtype=torch.long), out.attention_mask.to(device)
97
-
98
- pos_ids, pos_mask = _tok(prompt)
99
- pos_hidden = self.text_encoder(pos_ids, attention_mask=pos_mask).last_hidden_state
100
- pos_embeds = self.t5_projection(pos_hidden)
101
 
102
- if do_classifier_free_guidance:
103
- neg_prompt = negative_prompt if negative_prompt is not None else ""
104
- neg_ids, neg_mask = _tok(neg_prompt)
105
- neg_hidden = self.text_encoder(neg_ids, attention_mask=neg_mask).last_hidden_state
106
- neg_embeds = self.t5_projection(neg_hidden)
107
 
108
- # Expand for multiple images per prompt
109
- pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
110
- neg_embeds = neg_embeds.repeat_interleave(num_images_per_prompt, dim=0)
111
 
112
- return [neg_embeds, pos_embeds]
 
 
 
 
 
 
 
 
 
113
  else:
114
- pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
115
- return pos_embeds
 
 
1
+ # Copyright Philip Brown, ppbrown@github
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Note: At this time, the intent is to use the T5 encoder mentioned
16
+ # below, with zero changes.
17
+ # Therefore, the model deliberately does not store the T5 encoder model bytes,
18
+ # (Since they are not unique!)
19
+ # but instead takes advantage of huggingface hub cache loading
20
+
21
+ T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
22
+
23
+
24
+ from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
25
+ from transformers import T5Tokenizer, T5EncoderModel
26
+ from transformers import (
27
+ CLIPImageProcessor,
28
+ CLIPTextModel,
29
+ CLIPTextModelWithProjection,
30
+ CLIPTokenizer,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
 
34
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
37
 
 
 
 
 
 
 
38
 
39
+ from typing import Optional
40
 
41
+ import torch.nn as nn, torch, types
42
+
43
+ import torch.nn as nn
44
 
45
  class LinearWithDtype(nn.Linear):
46
  @property
47
  def dtype(self):
48
  return self.weight.dtype
49
 
 
50
 
51
+ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
52
+ _expected_modules = [
53
+ "vae", "unet", "scheduler", "tokenizer",
54
+ "image_encoder", "feature_extractor",
55
+ "t5_encoder", "t5_projection", "t5_pooled_projection",
56
+ ]
57
 
58
+ _optional_components = [
59
+ "image_encoder", "feature_extractor",
60
+ "t5_encoder", "t5_projection", "t5_pooled_projection",
61
+ ]
62
 
63
  def __init__(
64
  self,
65
+ vae: AutoencoderKL,
66
+ unet: UNet2DConditionModel,
67
+ scheduler: KarrasDiffusionSchedulers,
68
+ tokenizer: T5Tokenizer,
69
+ t5_encoder=None,
70
+ t5_projection=None,
71
+ t5_pooled_projection=None,
72
+ image_encoder: CLIPVisionModelWithProjection = None,
73
+ feature_extractor: CLIPImageProcessor = None,
74
+ force_zeros_for_empty_prompt: bool = True,
75
+ add_watermarker: Optional[bool] = None,
76
  ):
77
+ DiffusionPipeline.__init__(self)
 
 
 
 
78
 
79
+ if t5_encoder is None:
80
+ self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME,
81
+ torch_dtype=unet.dtype)
82
+ else:
83
+ self.t5_encoder = t5_encoder
84
 
85
+ # ----- build T5 4096 => 2048 dim projection -----
86
+ if t5_projection is None:
87
+ self.t5_projection = LinearWithDtype(4096, 2048) # trainable
88
  else:
89
+ self.t5_projection = t5_projection
90
+ self.t5_projection.to(dtype=unet.dtype)
91
+ # ----- build T5 4096 => 1280 dim projection -----
92
+ if t5_pooled_projection is None:
93
+ self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable
94
+ else:
95
+ self.t5_pooled_projection = t5_pooled_projection
96
+ self.t5_pooled_projection.to(dtype=unet.dtype)
97
+
98
+ print("dtype of Linear is ",self.t5_projection.dtype)
99
 
100
+ self.register_modules(
101
  vae=vae,
 
 
102
  unet=unet,
103
  scheduler=scheduler,
104
+ tokenizer=tokenizer,
105
+ t5_encoder=self.t5_encoder,
106
+ t5_projection=self.t5_projection,
107
+ t5_pooled_projection=self.t5_pooled_projection,
108
  image_encoder=image_encoder,
109
+ feature_extractor=feature_extractor,
110
+ )
111
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
112
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
113
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
114
+
115
+ self.default_sample_size = (
116
+ self.unet.config.sample_size
117
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
118
+ else 128
119
  )
120
 
121
+ self.watermark = None
 
 
 
122
 
123
+ # Parts of original SDXL class complain if these attributes are not
124
+ # at least PRESENT
125
+ self.text_encoder = self.text_encoder_2 = None
126
 
127
+ # ------------------------------------------------------------------------
128
+ # Encode a text prompt
129
+ # Use + 4096 => 2048 projection for standard embeds, but
130
+ # 4096 => 1280 for pooled embeds, because that's what the unet requires.
131
+ # Returns exactly four tensors in the order SDXL's __call__ expects.
132
+ # ------------------------------------------------------------------------
133
  def encode_prompt(
134
  self,
135
  prompt,
136
+ num_images_per_prompt: int = 1,
137
+ do_classifier_free_guidance: bool = True,
138
+ negative_prompt: str | None = None,
139
+ **_,
 
 
 
 
 
140
  ):
141
+ """
142
+ Returns
143
+ -------
144
+ prompt_embeds : Tensor [B, T, 2048]
145
+ negative_prompt_embeds : Tensor [B, T, 2048] | None
146
+ pooled_prompt_embeds : Tensor [B, 1280]
147
+ negative_pooled_prompt_embeds: Tensor [B, 1280] | None
148
+ where B = batch * num_images_per_prompt
149
+ """
150
+
151
+ # --- helper to tokenize on the pipeline's device ----------------
152
+ def _tok(text: str):
153
+ tok_out = self.tokenizer(
154
  text,
155
  return_tensors="pt",
156
  padding="max_length",
157
  max_length=self.tokenizer.model_max_length,
158
  truncation=True,
159
+ ).to(self.device)
160
+ return tok_out.input_ids, tok_out.attention_mask
 
 
 
 
161
 
162
+ # ---------- positive stream -------------------------------------
163
+ ids, mask = _tok(prompt)
164
+ h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
165
+ tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
166
+ pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280]
167
 
168
+ # expand for multiple images per prompt
169
+ tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
170
+ pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
171
 
172
+ # ---------- negative / CFG stream --------------------------------
173
+ if do_classifier_free_guidance:
174
+ neg_text = "" if negative_prompt is None else negative_prompt
175
+ ids_n, mask_n = _tok(neg_text)
176
+ h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
177
+ tok_neg = self.t5_projection(h_neg)
178
+ pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
179
+
180
+ tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
181
+ pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
182
  else:
183
+ tok_neg = pool_neg = None
184
+
185
+ return tok_pos, tok_neg, pool_pos, pool_neg