concauu committed on
Commit
2b19647
·
verified ·
1 Parent(s): 0e160be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -97
app.py CHANGED
@@ -14,65 +14,11 @@ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_
14
  from io import BytesIO
15
  import base64
16
  from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
17
- ###
18
- # Step 2: Modified pipeline class with proper component registration
19
- class T5FluxPipeline(DiffusionPipeline):
20
- def __init__(self, text_encoder, tokenizer, vae, unet, scheduler):
21
- super().__init__()
22
- self.device = device
23
- self.dtype = dtype
24
- self.register_modules(
25
- text_encoder=text_encoder,
26
- tokenizer=tokenizer,
27
- vae=vae,
28
- unet=unet,
29
- scheduler=scheduler
30
- )
31
- self.text_projection = torch.nn.Linear(768, 4096).to(device=device, dtype=dtype)
32
- torch.nn.init.normal_(self.text_projection.weight, std=0.02)
33
- torch.nn.init.zeros_(self.text_projection.bias)
34
-
35
- def encode_prompt(self, prompt, device, num_images_per_prompt=1,
36
- do_classifier_free_guidance=False, negative_prompt=None):
37
- text_inputs = self.tokenizer(
38
- prompt,
39
- padding="max_length",
40
- max_length=512,
41
- truncation=True,
42
- return_tensors="pt",
43
- ).to(device)
44
-
45
- text_embeddings = self.text_encoder(**text_inputs).last_hidden_state
46
- text_embeddings = self.text_projection(text_embeddings)
47
- pooled_embeddings = text_embeddings.mean(dim=1)
48
-
49
- if do_classifier_free_guidance:
50
- uncond_input = self.tokenizer(
51
- [negative_prompt] if negative_prompt else [""],
52
- padding="max_length",
53
- max_length=512,
54
- truncation=True,
55
- return_tensors="pt",
56
- ).to(device)
57
-
58
- uncond_embeddings = self.text_projection(
59
- self.text_encoder(**uncond_input).last_hidden_state
60
- )
61
- uncond_pooled = uncond_embeddings.mean(dim=1)
62
-
63
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
64
- pooled_embeddings = torch.cat([uncond_pooled, pooled_embeddings])
65
-
66
- text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
67
- pooled_embeddings = pooled_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
68
-
69
- return text_embeddings, pooled_embeddings, text_inputs.input_ids
70
-
71
- ###
72
-
73
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
74
  dtype = torch.bfloat16
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
76
  def get_hf_token(encrypted_token):
77
  # Retrieve the decryption key from an environment variable
78
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
@@ -95,38 +41,72 @@ decrypted_token = get_hf_token("gAAAAABn3GfShExoJd50nau3B5ZJNiQ9dRD1ACO3XXMwVaIQ
95
  login(token=decrypted_token)
96
  groq_client = Groq(api_key="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")
97
 
98
- vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
99
- unet = UNet2DConditionModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="unet")
100
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
102
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
103
- # Initialize pipeline with correct config
104
- try:
105
- t5_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
106
- t5_text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-base").to(device, dtype=dtype)
107
-
108
- pipe = DiffusionPipeline.from_pretrained(
109
- "black-forest-labs/FLUX.1-dev",
110
- custom_pipeline=T5FluxPipeline,
111
- text_encoder=t5_text_encoder,
112
- tokenizer=t5_tokenizer,
113
- torch_dtype=dtype,
114
- safety_checker=None,
115
- requires_safety_checker=False
116
- ).to(device)
117
-
118
- pipe.text_projection = pipe.text_projection.to(device, dtype=dtype)
119
- torch.cuda.empty_cache()
120
 
121
- except Exception as e:
122
- print(f"Model loading error: {str(e)}")
123
- raise
124
 
125
  MAX_SEED = np.iinfo(np.int32).max
126
  MAX_IMAGE_SIZE = 2048
127
 
128
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
 
 
 
 
130
 
131
  # History functions
132
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
@@ -182,28 +162,27 @@ def create_history_html(history):
182
 
183
 
184
  @spaces.GPU(duration=75)
185
- def infer(prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
 
186
  if randomize_seed:
187
  seed = random.randint(0, MAX_SEED)
188
  generator = torch.Generator().manual_seed(seed)
189
 
190
- # Force PIL output
191
- final_image = None
192
- for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
193
- prompt=prompt,
194
- guidance_scale=guidance_scale,
195
- num_inference_steps=num_inference_steps,
196
- width=width,
197
- height=height,
198
- generator=generator,
199
- output_type="pil",
200
- good_vae=good_vae,
201
- ):
202
- final_image = img # Keep updating until we get the final image
203
- yield img, seed # Live preview
204
 
205
- # Return the final image explicitly
206
- yield final_image, seed
 
 
 
 
 
 
 
 
 
207
 
208
 
209
 
 
14
  from io import BytesIO
15
  import base64
16
  from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  os.environ['HF_HUB_DOWNLOAD_TIMEOUT'] = '120'
18
  dtype = torch.bfloat16
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+
22
  def get_hf_token(encrypted_token):
23
  # Retrieve the decryption key from an environment variable
24
  key = "K4FlQbffvTcDxT2FIhrOPV1eue6ia45FFR3kqp2hHbM="
 
41
  login(token=decrypted_token)
42
  groq_client = Groq(api_key="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")
43
 
44
+
45
+ # Load T5 components for longer context
46
+ t5_tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-base", model_max_length=512)
47
+ t5_text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-base").to(device, dtype=dtype)
48
+
49
+ # Add projection layer to match CLIP's embedding dimensions
50
class TextProjection(torch.nn.Module):
    """Linear 768 -> 768 projection applied to text-encoder hidden states.

    The weight is drawn from N(0, 0.02^2) and the bias starts at zero.

    NOTE(review): nothing in the visible code loads trained weights for this
    layer, so it appears to run with its random initialisation -- confirm
    that this is intended.
    """

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(768, 768)  # T5-base to CLIP dimensions
        torch.nn.init.normal_(self.proj.weight, std=0.02)
        # Zero the bias as well; otherwise it keeps PyTorch's default uniform
        # init while the weight uses the custom normal init.
        torch.nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """Return the projection of ``x``.

        ``x`` is cast to the dtype of the layer's own parameters (rather than
        a module-level global), so the layer keeps working after any
        ``.to(dtype=...)`` call on it.
        """
        return self.proj(x.to(self.proj.weight.dtype))
58
+
59
+ # Initialize pipeline components
60
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
61
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
62
+ # Custom pipeline with T5 support
63
+ pipe = DiffusionPipeline.from_pretrained(
64
+ "black-forest-labs/FLUX.1-dev",
65
+ text_encoder=t5_text_encoder,
66
+ tokenizer=t5_tokenizer,
67
+ torch_dtype=dtype,
68
+ vae=taef1,
69
+ safety_checker=None
70
+ ).to(device)
 
 
 
 
 
 
 
 
71
 
72
+ # Add projection layer to pipeline
73
+ pipe.text_projection = TextProjection().to(device, dtype=dtype)
74
+ torch.cuda.empty_cache()
75
 
76
  MAX_SEED = np.iinfo(np.int32).max
77
  MAX_IMAGE_SIZE = 2048
78
 
79
+ # Monkey-patch the text encoding method
80
def custom_encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None):
    """Encode ``prompt`` with ``self.tokenizer`` / ``self.text_encoder`` and project it.

    Args:
        prompt: A single prompt string or a sequence of prompt strings.
        device: Device the tokenised inputs are moved to before encoding.
        num_images_per_prompt: How many times each prompt's embedding is
            repeated (consecutively) along the batch dimension.
        do_classifier_free_guidance: When true, embeddings for the negative
            prompt(s) are computed and concatenated in front of the prompt
            embeddings.
        negative_prompt: ``None``/empty (uses ``""``), a single string applied
            to every prompt, or a sequence with one entry per prompt.

    Returns:
        The projected embeddings. With classifier-free guidance the batch is
        ``[negative..., positive...]``; otherwise only the positive part.

    Raises:
        ValueError: If ``negative_prompt`` is a sequence whose length differs
            from the number of prompts.

    NOTE(review): this function is installed as ``pipe._encode_prompt``
    elsewhere in the file -- confirm that the pipeline class in use actually
    calls a method of that name with this signature.
    """
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)

    def embed(texts):
        # Tokenise to a fixed 512-token window, encode, then project.
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=512,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        hidden = self.text_encoder(**tokens).last_hidden_state
        return self.text_projection(hidden)

    # Honour num_images_per_prompt: each prompt's rows are repeated in place.
    text_embeddings = embed(prompts).repeat_interleave(num_images_per_prompt, dim=0)
    if not do_classifier_free_guidance:
        return text_embeddings

    # Build exactly one negative prompt per positive prompt so the two halves
    # of the guidance batch always have matching sizes.
    if not negative_prompt:
        negatives = [""] * len(prompts)
    elif isinstance(negative_prompt, str):
        negatives = [negative_prompt] * len(prompts)
    else:
        negatives = list(negative_prompt)
        if len(negatives) != len(prompts):
            raise ValueError(
                f"negative_prompt has {len(negatives)} entries but prompt has {len(prompts)}"
            )

    uncond_embeddings = embed(negatives).repeat_interleave(num_images_per_prompt, dim=0)
    return torch.cat([uncond_embeddings, text_embeddings])
107
+
108
+ pipe._encode_prompt = custom_encode_prompt.__get__(pipe)
109
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
110
 
111
  # History functions
112
  def append_to_history(image, prompt, seed, width, height, guidance_scale, steps, history):
 
162
 
163
 
164
  @spaces.GPU(duration=75)
165
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Stream generated images for ``prompt`` as ``(image, seed)`` pairs.

    Args:
        prompt: Text prompt; only its first 512 T5 token ids are kept.
        seed: Seed for the torch generator (replaced when ``randomize_seed``).
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` when true.
        width, height: Forwarded to the pipeline call.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.
        progress: Not referenced in the body; presumably consumed by Gradio
            for progress tracking -- verify against the UI wiring.

    Yields:
        ``(img, seed)`` for every image the pipeline iterable produces.
    """
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    generator = torch.Generator().manual_seed(seed)

    # Round-trip through the tokenizer so the prompt handed to the pipeline
    # corresponds to at most 512 token ids.
    token_ids = t5_tokenizer.encode(prompt)
    processed_prompt = t5_tokenizer.decode(token_ids[:512], skip_special_tokens=True)

    call_kwargs = dict(
        prompt=processed_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    )
    preview_stream = pipe.flux_pipe_call_that_returns_an_iterable_of_images(**call_kwargs)
    for img in preview_stream:
        yield img, seed
186
 
187
 
188