concauu committed on
Commit
20290bd
·
verified ·
1 Parent(s): c6f2fae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -1
app.py CHANGED
@@ -63,11 +63,45 @@ class TextProjection(torch.nn.Module):
63
  def forward(self, x):
64
  return self.proj(x.to(dtype))
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Initialize pipeline components
67
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
68
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
69
  # Custom pipeline with T5 support
70
- pipe = DiffusionPipeline.from_pretrained(
71
  "black-forest-labs/FLUX.1-dev",
72
  text_encoder=t5_text_encoder,
73
  tokenizer=t5_tokenizer,
 
63
  def forward(self, x):
64
  return self.proj(x.to(dtype))
65
 
66
+ # Add this override to your existing pipeline setup
67
+ class T5FluxPipeline(FluxPipeline):
68
+ def _get_clip_prompt_embeds(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
69
+ """Modified to work with T5 outputs"""
70
+ # Get T5 embeddings
71
+ text_inputs = self.tokenizer(
72
+ prompt,
73
+ padding="max_length",
74
+ max_length=512,
75
+ truncation=True,
76
+ return_tensors="pt",
77
+ ).to(device)
78
+
79
+ text_outputs = self.text_encoder(**text_inputs)
80
+ prompt_embeds = text_outputs.last_hidden_state
81
+
82
+ # Use mean pooling instead of CLIP's pooler_output
83
+ pooled_prompt_embeds = prompt_embeds.mean(dim=1)
84
+
85
+ # Expand for batch
86
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
87
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
88
+
89
+ # Handle guidance
90
+ if do_classifier_free_guidance:
91
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
92
+ negative_pooled = torch.zeros_like(pooled_prompt_embeds)
93
+
94
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
95
+ pooled_prompt_embeds = torch.cat([negative_pooled, pooled_prompt_embeds])
96
+
97
+ return prompt_embeds, pooled_prompt_embeds
98
+
99
+
100
  # Initialize pipeline components
101
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
102
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
103
  # Custom pipeline with T5 support
104
+ pipe = T5FluxPipeline.from_pretrained(
105
  "black-forest-labs/FLUX.1-dev",
106
  text_encoder=t5_text_encoder,
107
  tokenizer=t5_tokenizer,