| | --- |
| | license: apache-2.0 |
| | datasets: |
| | - AbstractPhil/human-templated-captions-1b |
| | base_model: |
| | - google-t5/t5-small |
| | - openai/clip-vit-large-patch14 |
| | pipeline_tag: any-to-any |
| | --- |
| | |
| |
|
| | This is a shunt adapter that takes in T5-small and CLIP ViT-L/14 (CLIP-L) embeddings simultaneously. |
| |
|
| | T5-small is used as a conditioning signal for normalization and guidance. |
| |
|
| | The shunt exposes many possible toggles, and there are many ways it could be used. |
| |
|
| | Only the basic mode, meant for simple text-encoder guidance, is hooked up here. As a test I also shunted its output directly into clip_embeds, only to see the results fall apart. |
| | |
| | The approach that worked with diffusers without a headache was monkey-patching the pipeline's encode_prompt method. |
| |
|
| |
|
| | Paste this into Colab and generate some SDXL images with it. The code is split across two cells; this one goes above the generation cell. |
| |
|
| | Experiment with the taps and settings to increase or reduce how much guidance the T5-small variations add to your CLIP-L embeddings. |
| | |
| | |
| | ``` |
| | import safetensors.torch as st |
| | import torch |
| | from diffusers import StableDiffusionXLPipeline |
| | from transformers import T5TokenizerFast, T5EncoderModel |
| | |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from tqdm.auto import tqdm |
| | |
# ------------------------------------------------------------
# Two-Stream Shunt Adapter
# ------------------------------------------------------------
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Both streams are projected into a shared bottleneck width and attended to
    each other in both directions. The fused result is decoded into per-CLIP-
    token outputs in CLIP space.

    Args:
        t5_dim: Hidden size of the T5 stream (default 512, matching t5-small).
        clip_dim: Hidden size of the CLIP stream (default 768, the CLIP-L slice).
        bottleneck: Shared width used for attention, convolution and fusion.
        heads: Number of attention heads (nn.MultiheadAttention requires it to
            divide ``bottleneck``).
        tau_init: Initial value of the per-head ``tau`` parameter.
        max_guidance: Scale applied to the mean sigmoid guidance score.

    ``forward`` returns, in order:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)   -- delta_proj output multiplied by gate
        log_sigma : (B, Lc, clip_dim)   -- predicted log sigma (raw linear output, not clamped)
        attn_t2c  : (B, heads, Lt, Lc)  -- per-head T5->CLIP attention weights
        attn_c2t  : (B, heads, Lc, Lt)  -- per-head CLIP->T5 attention weights
        tau       : (heads, 1, 1)       -- per-head threshold parameter (returned as-is)
        g_pred    : (B, 1)              -- guidance-scale prediction in (0, max_guidance)
        gate      : (B, Lc, 1)          -- per-token gate in (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # Project both streams into the shared bottleneck width.
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # Bidirectional cross-attention (dropout only matters in train mode).
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # Head-wise tau: stored and returned by forward, not used in its math.
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # Convolutional "pocket" residual: depth-wise (groups == channels) 1-D convs.
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # Fusion of [pocket mean, CLIP->T5 context] back to bottleneck width.
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        # Output heads.
        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, args, *extra, **kwargs):
        """Load weights after stripping ``_orig_mod.`` from every key.

        The prefix is presumably left behind by saving a ``torch.compile``-
        wrapped module; removing it lets such checkpoints load into this
        plain module.

        Args:
            args: Mapping of parameter/buffer name -> tensor.
            *extra, **kwargs: Forwarded unchanged to ``nn.Module.load_state_dict``
                (e.g. ``strict``, positionally or by keyword).

        Returns:
            The base implementation's result (``missing_keys`` /
            ``unexpected_keys``), so ``strict=False`` callers can inspect it.
        """
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in args.items()}
        return super().load_state_dict(state_dict, *extra, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        """Align the two streams and decode CLIP-space outputs.

        Args:
            t5_seq: T5 hidden states, (B, Lt, t5_dim).
            clip_seq: CLIP hidden states, (B, Lc, clip_dim).

        Returns:
            The 8-tuple documented on the class.

        Raises:
            ValueError: if either input is not 3-D.
        """
        print("SHUNT FORWARD CALLED")

        if t5_seq.dim() != 3 or clip_seq.dim() != 3:
            raise ValueError(
                "expected 3-D inputs (B, L, dim); got "
                f"t5_seq {tuple(t5_seq.shape)} and clip_seq {tuple(clip_seq.shape)}"
            )
        Lc = clip_seq.size(1)

        # 1) project into bottleneck
        t5_b = self.proj_t5(t5_seq)        # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) cross-attention in both directions, keeping per-head weights
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on the T5->CLIP output (conv runs over the Lt axis)
        x = t2c.transpose(1, 2)                    # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)   # (B, Lt, b)
        pocket = self.norm_res(t2c + x)            # (B, Lt, b)

        # 4) fuse the pocket's mean over Lt with the CLIP->T5 context, per CLIP token
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor = self.anchor_proj(h)      # (B, Lc, clip_dim)
        delta_mean = self.delta_proj(h)   # (B, Lc, clip_dim)
        log_sigma = self.logsig_proj(h)   # (B, Lc, clip_dim)
        gate = self.gate_proj(h)          # (B, Lc, 1)
        delta = delta_mean * gate         # (B, Lc, clip_dim)

        # Per-token sigmoid scores averaged over Lc, scaled to (0, max_guidance).
        g_tok = self.guidance_proj(h).squeeze(-1)  # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
| | |
# --- 1. SDXL base pipeline, fp16, on the GPU ----------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# --- 2. tiny T5 encoder + shunt adapter, both kept in fp32 ---------------
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small")
t5_mod = t5_mod.eval().to("cuda")

shunt = TwoStreamShuntAdapter()
shunt = shunt.float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file(
        "/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors"
    )
)

# --- 3. keep a handle on the unpatched encoder ---------------------------
# Captured in this load-once cell so that re-running the generation cell
# wraps the pipeline's original encode_prompt, not an already-patched one.
orig_encode = pipe.encode_prompt

# Tuning knobs. NOTE(review): nothing in this card reads `config` yet --
# TODO wire these into the shunted encoder or remove them.
config = dict(
    strength=1.0,
    gate_gamma=1.0,
    tau_scale=1.0,
    guidance_gain=1.0,
    guidance_bias=0.0,
)

# Seeded CUDA generator (manual_seed returns the same generator object).
gen = torch.Generator(device="cuda")
gen.manual_seed(420)
| | ``` |
| | |
| | Place this in a separate cell so you don't reload the models over and over. |
| | |
| | ``` |
# Default shunt strength; the generation loop reassigns this module-level
# name before every pipeline call.
strength = 0


def stable_encode_prompt_shunted(self, *args, **kw):
    """Alternative ``encode_prompt`` wrapper that feeds T5 a fixed dummy text.

    Runs the original encoder, then shifts only the CLIP-L slice (first 768
    dims) of the positive prompt embeddings by the shunt's ``delta`` output,
    scaled by the module-level ``strength``. The addition is done in fp32 and
    the result is cast back to the embedding dtype.

    Every T5 row is the placeholder text ``"tmp"``, one per positive-embedding
    row. The shift therefore does not depend on the user's prompt. Negative
    and pooled embeddings are returned untouched; ``encode_prompt`` hands them
    back as separate outputs, so no CFG concatenation is involved here.

    This variant is kept for reference. The script installs
    ``encode_prompt_shunted`` on the pipeline instead.
    """
    prompt_embeds, neg_embeds, pooled, neg_pooled = orig_encode(*args, **kw)

    # CLIP-L is the first 768 feature dims; CLIP-G is everything after.
    clip_l = prompt_embeds[..., :768]
    clip_g = prompt_embeds[..., 768:]

    # One placeholder T5 row per positive-embedding row.
    batch = clip_l.shape[0]
    token_ids = t5_tok(["tmp"] * batch, return_tensors="pt").input_ids.to("cuda")
    t5_hidden = t5_mod(token_ids).last_hidden_state

    # The adapter runs in fp32; output index 1 is the gated delta.
    adapter_out = shunt(t5_hidden.float(), clip_l.float())
    shift = adapter_out[1] * strength
    shifted_l = (clip_l.float() + shift).to(clip_l.dtype)

    return torch.cat([shifted_l, clip_g], dim=-1), neg_embeds, pooled, neg_pooled
# -----------------------------------------------------------------------------


def _t5_prompt_texts(args, kwargs):
    """Pick the text(s) to feed T5 for one ``encode_prompt`` call.

    Prefers ``prompt_2`` (the "context" prompt) and falls back to ``prompt``
    when ``prompt_2`` is missing, ``None`` or empty. Keyword arguments win.
    Positional arguments are read assuming diffusers' ``encode_prompt(prompt,
    prompt_2, ...)`` order (TODO confirm for the installed diffusers version).

    Returns:
        A list of strings, or ``None`` when no text was supplied (e.g. the
        caller passed precomputed ``prompt_embeds``).
    """
    prompt = kwargs.get("prompt", args[0] if len(args) > 0 else None)
    prompt_2 = kwargs.get("prompt_2", args[1] if len(args) > 1 else None)
    text = prompt_2 or prompt
    if text is None:
        return None
    if isinstance(text, str):
        return [text]
    return list(text)


def encode_prompt_shunted(self, *a, **k):
    """Replacement for ``pipe.encode_prompt`` that nudges CLIP-L with the shunt.

    Calls the original encoder, splits the positive prompt embeddings into the
    CLIP-L slice (first 768 dims) and the CLIP-G slice (remaining dims), adds
    the shunt's ``delta`` output scaled by the module-level ``strength`` to the
    CLIP-L slice, and re-concatenates. Negative and pooled embeddings are
    returned unchanged.

    T5 is run on ``prompt_2``, falling back to ``prompt``. If the CLIP batch is
    larger than the number of prompts, each T5 row is repeated to match. This
    assumes each prompt's copies are consecutive when
    ``num_images_per_prompt > 1`` (TODO confirm).

    Args:
        self: The pipeline this function is bound to (unused; ``orig_encode``
            is already bound).
        *a, **k: Forwarded unchanged to the original ``encode_prompt``.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled, negative_pooled)``
        with only ``prompt_embeds`` modified.

    Raises:
        ValueError: if the CLIP batch is not a multiple of the number of prompts.
    """
    pe, ne, pool, npool = orig_encode(*a, **k)  # (B, 77, 2048)

    texts = _t5_prompt_texts(a, k)
    if not texts:
        # Nothing to condition T5 on (e.g. precomputed embeds): pass through.
        return pe, ne, pool, npool

    # Split CLIP-L / CLIP-G.
    clip_l, clip_g = pe[..., :768], pe[..., 768:]

    # T5 on the context text. Padding + attention mask let several prompts
    # share one batch; for a single prompt the mask is all ones.
    tokens = t5_tok(texts, return_tensors="pt", padding=True, truncation=True)
    t5_seq = t5_mod(
        input_ids=tokens.input_ids.to(pe.device),
        attention_mask=tokens.attention_mask.to(pe.device),
    ).last_hidden_state.float()

    # Match the CLIP batch (which includes num_images_per_prompt copies).
    bsz, n_text = clip_l.shape[0], t5_seq.shape[0]
    if bsz != n_text:
        if bsz % n_text:
            raise ValueError(
                f"CLIP batch ({bsz}) is not a multiple of the number of T5 prompts ({n_text})"
            )
        t5_seq = t5_seq.repeat_interleave(bsz // n_text, dim=0)

    # Shunt runs in fp32; delta is cast back to the pipeline dtype before scaling.
    delta = shunt(t5_seq, clip_l.float())[1].to(clip_l.dtype)
    clip_l_shift = clip_l + delta * strength

    pe_shift = torch.cat([clip_l_shift, clip_g], dim=-1)
    return pe_shift, ne, pool, npool


# Bind as a method of this pipeline instance only; the pipeline's __call__
# presumably looks up self.encode_prompt, so generation uses the shunted path.
pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))
| | |
| | |
| | |
| | |
PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5

# Sweep: each step raises the shunt strength by 0.25 and lowers CFG by 0.25.
# The seed is fixed per image, so only those two knobs differ between outputs.
for i in range(4):
    # Module-level `strength` is read by the shunted encode_prompt at call time.
    strength = base_strength + i * 0.25
    cfg = base_cfg - i * 0.25
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        # The SDXL pipeline's CFG argument is `guidance_scale`; the previous
        # `cfg_scale=` keyword is not one of its parameters.
        guidance_scale=cfg,
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    img.save(f"woman_cfg_{round(cfg * 100)}_{round(strength * 100)}.png")
| | |
| | # --- 4. generate ------------------------------------------------------- |
| | #img = pipe( |
| | # PROMPT, |
| | # negative_prompt=NEG, |
| | # num_inference_steps=STEPS, |
| | # generator=torch.Generator(device="cuda").manual_seed(420) |
| | # ).images[0] |
| | #img.save("majestic_baseline.png")# |
| | # |
| | |
| | #strength = 0.25 |
| | ## --- 4. generate ------------------------------------------------------- |
| | #img = pipe( |
| | # PROMPT, |
| | # negative_prompt=NEG, |
| | # num_inference_steps=STEPS, |
| | # generator=torch.Generator(device="cuda").manual_seed(420) |
| | # ).images[0] |
| | #img.save("majestic_02.png")# |
| |
|
| | #strength = 0.5 |
| | ## --- 4. generate ------------------------------------------------------- |
| | #img = pipe( |
| | # PROMPT, |
| | # negative_prompt=NEG, |
| | # num_inference_steps=STEPS, |
| | # generator=torch.Generator(device="cuda").manual_seed(420) |
| | # ).images[0] |
| | #img.save("majestic_05.png")# |
| | |
| | #strength = 0.75 |
| | ## --- 4. generate ------------------------------------------------------- |
| | #img = pipe( |
| | # PROMPT, |
| | # negative_prompt=NEG, |
| | # num_inference_steps=STEPS, |
| | # generator=torch.Generator(device="cuda").manual_seed(420) |
| | # ).images[0] |
| | #img.save("majestic_075.png") |
| | ``` |