|
|
from dataclasses import dataclass |
|
|
import math |
|
|
|
|
|
import torch |
|
|
from torch import Tensor, nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .modules.layers import ( |
|
|
DoubleStreamBlock, |
|
|
EmbedND, |
|
|
EmbedNDFlux2, |
|
|
LastLayer, |
|
|
MLPEmbedder, |
|
|
Modulation, |
|
|
SingleStreamBlock, |
|
|
timestep_embedding, |
|
|
DistilledGuidance, |
|
|
ChromaModulationOut, |
|
|
SigLIPMultiFeatProjModel, |
|
|
) |
|
|
from .modules.lora import LinearLora, replace_linear_with_lora |
|
|
from .radiance import apply_radiance_head, inject_radiance_modules |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class FluxParams: |
|
|
in_channels: int |
|
|
out_channels: int |
|
|
vec_in_dim: int |
|
|
context_in_dim: int |
|
|
hidden_size: int |
|
|
mlp_ratio: float |
|
|
num_heads: int |
|
|
depth: int |
|
|
depth_single_blocks: int |
|
|
axes_dim: list[int] |
|
|
theta: int |
|
|
qkv_bias: bool |
|
|
guidance_embed: bool |
|
|
chroma: bool = False |
|
|
eso: bool = False |
|
|
radiance: bool = False |
|
|
radiance_patch_size: int = 16 |
|
|
radiance_hidden_size: int = 64 |
|
|
radiance_mlp_ratio: int = 4 |
|
|
radiance_depth: int = 4 |
|
|
radiance_max_freqs: int = 8 |
|
|
radiance_tile_size: int = 0 |
|
|
radiance_final_head_type: str = "conv" |
|
|
single_linear1_mlp_ratio: float | None = None |
|
|
single_mlp_hidden_ratio: float | None = None |
|
|
double_mlp_ratio: float | None = None |
|
|
double_linear1_mlp_ratio: float | None = None |
|
|
flux2: bool = False |
|
|
piflow: bool = False |
|
|
|
|
|
class Flux(nn.Module): |
|
|
""" |
|
|
Transformer model for flow matching on sequences. |
|
|
""" |
|
|
def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
    """Pick the modulation entries for one block out of ``tensor``.

    Layout along dim 1 of ``tensor``, as implied by the offsets used here:

    * single block ``idx`` starts at ``3 * idx``;
    * the image stream of double block ``idx`` starts at
      ``3 * depth_single_blocks + 6 * idx``;
    * the text stream of double block ``idx`` starts a further
      ``6 * depth`` entries later;
    * the final layer uses the last two entries.

    Args:
        tensor: modulation tensor; indexed as ``tensor[:, i, :]``.
        block_type: ``"final"``, ``"single"``, ``"double_img"`` or
            ``"double_txt"``.
        idx: index of the block within its group (ignored for ``"final"``).

    Returns:
        For ``"final"`` a pair of slices ``(tensor[:, -2:-1], tensor[:, -1:])``.
        For ``"single"`` one ``ChromaModulationOut``.  For the double variants
        a pair of ``ChromaModulationOut`` built at ``offset`` and
        ``offset + 3``.

    Raises:
        ValueError: if ``block_type`` is none of the four names above.
    """
    if block_type == "final":
        return tensor[:, -2:-1, :], tensor[:, -1:, :]

    model_params = self.params
    n_single = model_params.depth_single_blocks
    n_double = model_params.depth

    if block_type == "single":
        return ChromaModulationOut.from_offset(tensor, 3 * idx)

    if block_type not in {"double_img", "double_txt"}:
        raise ValueError("Bad block_type")

    # Double blocks consume six entries each and sit after all single blocks.
    start = 6 * idx + 3 * n_single
    if block_type == "double_txt":
        # The text-stream entries follow the complete image-stream region.
        start += 6 * n_double
    return (
        ChromaModulationOut.from_offset(tensor, start),
        ChromaModulationOut.from_offset(tensor, start + 3),
    )
|
|
|
|
|
def __init__(self, params: FluxParams):
    """Create every sub-module described by ``params``.

    Args:
        params: model configuration.  The ``flux2``, ``radiance``, ``piflow``
            flags and the per-block ratio overrides are read with
            ``getattr`` so that parameter objects lacking them still work.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads`` or
            if ``sum(axes_dim)`` differs from ``hidden_size // num_heads``.
    """
    super().__init__()

    self.params = params
    self.in_channels = params.in_channels
    self.out_channels = params.out_channels
    self.chroma = params.chroma
    self.is_flux2 = getattr(params, "flux2", False)
    self.radiance = getattr(params, "radiance", False)
    self.piflow = getattr(params, "piflow", False)

    if params.hidden_size % params.num_heads != 0:
        raise ValueError(
            f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
        )
    pe_dim = params.hidden_size // params.num_heads
    if sum(params.axes_dim) != pe_dim:
        raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = params.hidden_size
    self.num_heads = params.num_heads

    # Flux2 variants drop the bias on every projection built below and, like
    # chroma, use block-external ("shared") modulation.
    use_bias = not self.is_flux2
    shared_modulation = self.chroma or self.is_flux2

    embedder_cls = EmbedNDFlux2 if self.is_flux2 else EmbedND
    self.pe_embedder = embedder_cls(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

    # With radiance the plain linear input projection is replaced by identity.
    if self.radiance:
        self.img_in = nn.Identity()
    else:
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=use_bias)

    if params.guidance_embed:
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=use_bias)
    else:
        self.guidance_in = None

    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=use_bias)

    if self.is_flux2:
        self.double_stream_modulation_img = Modulation(self.hidden_size, double=True, bias=False)
        self.double_stream_modulation_txt = Modulation(self.hidden_size, double=True, bias=False)
        self.single_stream_modulation = Modulation(self.hidden_size, double=False, bias=False)

    if self.chroma:
        # Chroma has no time/vector embedders; modulation comes from this layer.
        self.distilled_guidance_layer = DistilledGuidance(
            in_dim=64,
            hidden_dim=5120,
            out_dim=3072,
            n_layers=5,
        )
    else:
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=use_bias)
        if self.is_flux2:
            self.vector_in = nn.Identity()
        else:
            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, bias=use_bias)

    double_blocks = []
    for _ in range(params.depth):
        double_blocks.append(
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                double_mlp_ratio=getattr(params, "double_mlp_ratio", None),
                double_linear1_mlp_ratio=getattr(params, "double_linear1_mlp_ratio", None),
                mod_bias=use_bias,
                mlp_bias=use_bias,
                proj_bias=use_bias,
                qkv_bias=params.qkv_bias,
                shared_modulation=shared_modulation,
            )
        )
    self.double_blocks = nn.ModuleList(double_blocks)

    single_blocks = []
    for _ in range(params.depth_single_blocks):
        single_blocks.append(
            SingleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                shared_modulation=shared_modulation,
                single_linear1_mlp_ratio=getattr(params, "single_linear1_mlp_ratio", None),
                single_mlp_hidden_ratio=getattr(params, "single_mlp_hidden_ratio", None),
                qk_scale=None,
                linear_bias=use_bias,
                modulation_bias=use_bias,
            )
        )
    self.single_blocks = nn.ModuleList(single_blocks)

    if self.radiance:
        # The radiance helper attaches its own modules to this instance; no
        # LastLayer is created in that case.
        inject_radiance_modules(self, params)
        self.final_layer = None
    else:
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            chroma_modulation=self.chroma,
            use_linear=True,
            linear_bias=use_bias,
            modulation_bias=use_bias,
        )

    if self.piflow:
        self.proj_out_means = nn.Linear(self.hidden_size, 1024, bias=True)
        self.proj_out_logweights = nn.Linear(self.hidden_size, 32, bias=True)
        self.proj_out_logstds = nn.Sequential(
            nn.Identity(),
            nn.Linear(self.hidden_size, 1024, bias=True),
            nn.SiLU(),
            nn.Linear(1024, 1, bias=True),
        )
        # Derived sizes: the projection widths are split by channel count to
        # obtain the mixture size and the per-token patch side length.
        means_width = self.proj_out_means.out_features
        logweights_width = self.proj_out_logweights.out_features
        self.num_gaussians = means_width // self.out_channels
        self.logweights_channels = logweights_width // self.num_gaussians
        self.piflow_patch_size = int(math.sqrt(self.logweights_channels))
|
|
|
|
|
def _apply_final_layer(self, tokens: Tensor, vec): |
|
|
final_layer = self.final_layer |
|
|
normed = final_layer.norm_final(tokens) |
|
|
if self.chroma: |
|
|
shift, scale = vec |
|
|
shift = shift.squeeze(1) |
|
|
scale = scale.squeeze(1) |
|
|
else: |
|
|
shift, scale = final_layer.adaLN_modulation(vec).chunk(2, dim=1) |
|
|
modulated = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], normed) |
|
|
if final_layer.linear is None: |
|
|
raise RuntimeError("Final layer projection is not available.") |
|
|
base_tokens = final_layer.linear(modulated) |
|
|
return modulated, base_tokens |
|
|
|
|
|
def _apply_piflow_final_layer(self, tokens: Tensor, vec: Tensor, img_ids: Tensor, img_len: int): |
|
|
if img_ids is None: |
|
|
raise RuntimeError("pi-Flow requires image ids to reshape outputs.") |
|
|
final_layer = self.final_layer |
|
|
normed = final_layer.norm_final(tokens) |
|
|
shift, scale = final_layer.adaLN_modulation(vec).chunk(2, dim=1) |
|
|
modulated = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], normed) |
|
|
|
|
|
proj_dtype = self.proj_out_means.weight.dtype |
|
|
modulated = modulated.to(proj_dtype) |
|
|
means = self.proj_out_means(modulated) |
|
|
logweights = self.proj_out_logweights(modulated) |
|
|
logstds = self.proj_out_logstds(vec.detach().to(self.proj_out_logstds[-1].weight.dtype)) |
|
|
|
|
|
base_img_ids = img_ids[:, :img_len] |
|
|
h_len = int(base_img_ids[..., 1].max().item() + 1) |
|
|
w_len = int(base_img_ids[..., 2].max().item() + 1) |
|
|
if h_len * w_len != img_len: |
|
|
raise RuntimeError("pi-Flow token length does not match latent grid.") |
|
|
|
|
|
patch_size = self.piflow_patch_size |
|
|
if patch_size * patch_size != self.logweights_channels: |
|
|
raise RuntimeError("pi-Flow logweights channels mismatch.") |
|
|
|
|
|
bsz = means.shape[0] |
|
|
k = self.num_gaussians |
|
|
c = self.out_channels |
|
|
c_unpacked = c // (patch_size * patch_size) |
|
|
|
|
|
means = means.view( |
|
|
bsz, h_len, w_len, k, c_unpacked, patch_size, patch_size |
|
|
).permute( |
|
|
0, 3, 4, 1, 5, 2, 6 |
|
|
).reshape( |
|
|
bsz, k, c_unpacked, h_len * patch_size, w_len * patch_size |
|
|
) |
|
|
|
|
|
logweights = logweights.view( |
|
|
bsz, h_len, w_len, k, 1, patch_size, patch_size |
|
|
).permute( |
|
|
0, 3, 4, 1, 5, 2, 6 |
|
|
).reshape( |
|
|
bsz, k, 1, h_len * patch_size, w_len * patch_size |
|
|
).log_softmax(dim=1) |
|
|
|
|
|
logstds = logstds.reshape(bsz, 1, 1, 1, 1) |
|
|
return modulated, {"means": means, "logweights": logweights, "logstds": logstds} |
|
|
|
|
|
def preprocess_loras(self, model_type, sd):
    """Rename the keys of a LoRA state dict to this model's module names.

    The format is decided by the first key that contains either
    ``"lora_unet_"`` or ``"transformer_blocks"``:

    * ``"lora_unet_"``: underscore-joined keys are rewritten to dotted
      ``diffusion_model.*`` paths and ``lora_up`` / ``lora_down`` become
      ``lora_B`` / ``lora_A``.
    * ``"transformer_blocks"`` (diffusers naming, which also covers
      ``single_transformer_blocks``): a leading ``"transformer."`` is
      stripped, block and root module names are mapped through the
      substitution tables below, and every key is prefixed with
      ``"diffusion_model."``.  For ``norm_out.linear`` the two halves of the
      ``lora_B`` weight are swapped along dim 0.
    * neither: ``sd`` is returned unchanged.

    Args:
        model_type: unused by this method.
        sd: mapping from key to weight; may be empty or ``None``.

    Returns:
        ``sd`` itself when it is empty, ``None`` or in an unrecognised
        format; otherwise a new dict with converted keys.
    """
    if not sd:
        return sd

    def swap_scale_shift(weight):
        # The stored halves are in the opposite order to what the target expects.
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    lora_unet = diffusers = False
    for key in sd:
        if "lora_unet_" in key:
            lora_unet = True
            break
        # Also matches "single_transformer_blocks".
        if "transformer_blocks" in key:
            diffusers = True
            break

    if lora_unet:
        print("Converting Lora Safetensors format to Lora Diffusers format")
        block_prefixes = (
            ("lora_unet_blocks_", "diffusion_model.blocks."),
            ("lora_unet__blocks_", "diffusion_model.blocks."),
            ("lora_unet_single_blocks_", "diffusion_model.single_blocks."),
            ("lora_unet_double_blocks_", "diffusion_model.double_blocks."),
        )
        module_names = (
            "linear1", "linear2", "modulation", "img_attn", "txt_attn",
            "img_mlp", "txt_mlp", "img_mod", "txt_mod",
        )

        converted = {}
        for key, value in sd.items():
            for old, new in block_prefixes:
                key = key.replace(old, new)
            # "_name." and "_name_" both become ".name."
            for name in module_names:
                dotted = "." + name + "."
                key = key.replace("_" + name + ".", dotted)
                key = key.replace("_" + name + "_", dotted)
            key = key.replace("lora_up", "lora_B")
            key = key.replace("lora_down", "lora_A")
            converted[key] = value
        return converted

    if not diffusers:
        return sd

    # Substitutions are applied sequentially, so the order of each table matters.
    root_map = (
        ("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer"),
        ("time_text_embed.timestep_embedder.linear_2", "time_in.out_layer"),
        ("time_text_embed.text_embedder.linear_1", "vector_in.in_layer"),
        ("time_text_embed.text_embedder.linear_2", "vector_in.out_layer"),
        ("time_text_embed.guidance_embedder.linear_1", "guidance_in.in_layer"),
        ("time_text_embed.guidance_embedder.linear_2", "guidance_in.out_layer"),
        ("x_embedder", "img_in"),
        ("context_embedder", "txt_in"),
        ("proj_out", "final_layer.linear"),
        ("time_guidance_embed.timestep_embedder.linear_1", "time_in.in_layer"),
        ("time_guidance_embed.timestep_embedder.linear_2", "time_in.out_layer"),
    )
    double_map = (
        ("norm1.linear", "img_mod.lin"),
        ("norm1_context.linear", "txt_mod.lin"),
        ("attn.norm_q", "img_attn.norm.query_norm"),
        ("attn.norm_k", "img_attn.norm.key_norm"),
        ("ff.net.0.proj", "img_mlp.0"),
        ("ff.net.2", "img_mlp.2"),
        ("ff_context.net.0.proj", "txt_mlp.0"),
        ("ff_context.net.2", "txt_mlp.2"),
        ("attn.to_out.0", "img_attn.proj"),
        ("attn.to_add_out", "txt_attn.proj"),
        ("attn.to_out", "img_attn.proj"),
        (".attn.to_", ".img_attn."),
        (".attn.add_q_proj.", ".txt_attn.q."),
        (".attn.add_k_proj.", ".txt_attn.k."),
        (".attn.add_v_proj.", ".txt_attn.v."),
        (".ff_context.linear_out.", ".txt_mlp.2."),
        (".ff_context.linear_in.", ".txt_mlp.0."),
        (".ff.linear_out.", ".img_mlp.2."),
        (".ff.linear_in.", ".img_mlp.0."),
    )
    single_map = (
        ("norm.linear", "modulation.lin"),
        ("attn.norm_q", "norm.query_norm"),
        ("attn.norm_k", "norm.key_norm"),
        ("proj_out", "linear2"),
        (".attn.to_q.", ".linear1_attn_q."),
        (".attn.to_k.", ".linear1_attn_k."),
        (".attn.to_v.", ".linear1_attn_v."),
        (".proj_mlp.", ".linear1_mlp."),
        (".attn.to_out.", ".linear2."),
    )

    converted = {}
    for key, value in sd.items():
        # Strip only the leading prefix; later "transformer." substrings are
        # part of the module path and must survive.
        key = key.removeprefix("transformer.")

        if key.startswith("single_transformer_blocks"):
            key = key.replace("single_transformer_blocks", "single_blocks")
            table = single_map
        elif key.startswith("transformer_blocks"):
            key = key.replace("transformer_blocks", "double_blocks")
            table = double_map
        else:
            table = root_map
        for old, new in table:
            key = key.replace(old, new)

        if "norm_out.linear" in key:
            if "lora_B" in key:
                value = swap_scale_shift(value)
            key = key.replace("norm_out.linear", "final_layer.adaLN_modulation.1")

        if not key.startswith("diffusion_model."):
            key = "diffusion_model." + key
        converted[key] = value
    return converted
|
|
|
|
|
def forward(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt_list,
    txt_ids_list,
    timesteps: Tensor,
    y_list,
    img_len=0,
    guidance: Tensor | None = None,
    callback=None,
    pipeline=None,
    siglip_embedding=None,
    siglip_embedding_ids=None,
) -> list:
    """Run the transformer once per text stream and return one output each.

    The image tokens are embedded once and cloned for every additional entry
    of ``txt_list``; each stream then passes through all double blocks, is
    concatenated as ``[txt, img]``, passes through all single blocks, and the
    ``img_len`` tokens following the text are sent to the output head.

    Args:
        img: image tokens.  With radiance, the first ``img_len`` tokens are
            folded back into an image and re-embedded by ``self.img_in_patch``.
        img_ids: position ids for the image tokens; columns 1 and 2 are read
            as row / column indices when a grid size is needed.
        txt_list: text embeddings, one entry per stream.
        txt_ids_list: position ids matching ``txt_list``.
        timesteps: timestep tensor handed to ``timestep_embedding``.
        y_list: per-stream inputs for ``self.vector_in`` (non-chroma only).
        img_len: number of image tokens kept for the output head.
        guidance: guidance strength; required when the model was built with
            ``guidance_embed``.  Ignored for chroma, which embeds zeros.
        callback: optional progress hook, called as
            ``callback(-1, None, False, True)`` before every block.
        pipeline: optional object whose ``_interrupt`` attribute aborts the
            pass when truthy.  ``None`` disables the check.
        siglip_embedding: optional tokens prepended to every text stream.
        siglip_embedding_ids: position ids for ``siglip_embedding``.

    Returns:
        A list with one entry per stream: the radiance head output, the
        pi-Flow dict, or the projected tokens.  When interrupted, a list of
        ``None`` of the same length.

    Raises:
        ValueError: if guidance embedding is enabled but ``guidance`` is
            ``None``.
    """
    sz = len(txt_list)
    height = width = None
    base_image_list = None

    if self.radiance:
        # NOTE(review): ``patch_size`` and ``img_in_patch`` are not set in
        # ``__init__``; presumably ``inject_radiance_modules`` provides them.
        patch_size = self.patch_size
        base_ids = img_ids[:, :img_len, :]
        height = int(base_ids[..., 1].max().item() + 1) * patch_size
        width = int(base_ids[..., 2].max().item() + 1) * patch_size

        tokens = img[:, :img_len, :].transpose(1, 2)
        image = F.fold(
            tokens,
            output_size=(height, width),
            kernel_size=patch_size,
            stride=patch_size,
        )
        img = self.img_in_patch(image).flatten(2).transpose(1, 2)
        # One private copy per stream (previously capped at two streams).
        base_image_list = [image] + [image.clone() for _ in range(sz - 1)]
    else:
        img = self.img_in(img)
    # Blocks write their result in place, so every stream needs its own copy.
    img_list = [img] + [img.clone() for _ in range(sz - 1)]

    if self.chroma:
        # 344 == 3 * 38 + 12 * 19 + 2 under the layout read by
        # ``get_modulations``; presumably tied to the reference depths --
        # TODO confirm before deriving it from ``self.params``.
        mod_index_length = 344
        distill_timestep = timestep_embedding(timesteps, 16).to(img.device, img.dtype)
        zero_guidance = torch.tensor([0.0] * distill_timestep.shape[0])
        distil_guidance = timestep_embedding(zero_guidance, 16).to(img.device, img.dtype)
        modulation_index = timestep_embedding(
            torch.arange(mod_index_length, device=img.device), 32
        ).to(img.device, img.dtype)
        modulation_index = (
            modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
        )
        timestep_guidance = (
            torch.cat([distill_timestep, distil_guidance], dim=1)
            .unsqueeze(1)
            .repeat(1, mod_index_length, 1)
            .to(img.device, img.dtype)
        )
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
        mod_vectors = self.distilled_guidance_layer(input_vec)
    else:
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed and self.guidance_in is not None:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec += self.guidance_in(timestep_embedding(guidance, 256))
        base_vec_list = [vec + self.vector_in(y) for y in y_list]
        vec_list = base_vec_list

    # Drop the local reference; the per-stream copies live in ``img_list``.
    img = None
    txt_list = [self.txt_in(txt) for txt in txt_list]
    if siglip_embedding is not None:
        txt_list = [torch.cat((siglip_embedding, txt), dim=1) for txt in txt_list]
        txt_ids_list = [torch.cat((siglip_embedding_ids, txt_id), dim=1) for txt_id in txt_ids_list]

    pe_list = [self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)) for txt_ids in txt_ids_list]

    if self.is_flux2:
        double_vec_list = [
            (
                self.double_stream_modulation_img(base_vec_list[i]),
                self.double_stream_modulation_txt(base_vec_list[i]),
            )
            for i in range(sz)
        ]

    for i, block in enumerate(self.double_blocks):
        if self.chroma:
            vec_list = [
                (
                    self.get_modulations(mod_vectors, "double_img", idx=i),
                    self.get_modulations(mod_vectors, "double_txt", idx=i),
                )
            ] * sz
        elif self.is_flux2:
            vec_list = double_vec_list
        if callback is not None:
            callback(-1, None, False, True)
        # ``pipeline`` defaults to None, so guard before reading the flag.
        if pipeline is not None and pipeline._interrupt:
            return [None] * sz
        for img, txt, pe, vec in zip(img_list, txt_list, pe_list, vec_list):
            img[...], txt[...] = block(img=img, txt=txt, vec=vec, pe=pe)
        img = txt = pe = vec = None

    img_list = [torch.cat((txt, img), 1) for txt, img in zip(txt_list, img_list)]

    if self.is_flux2:
        single_vec_list = [self.single_stream_modulation(base_vec_list[i])[0] for i in range(sz)]

    for i, block in enumerate(self.single_blocks):
        if self.chroma:
            vec_list = [self.get_modulations(mod_vectors, "single", idx=i)] * sz
        elif self.is_flux2:
            vec_list = single_vec_list

        if callback is not None:
            callback(-1, None, False, True)
        if pipeline is not None and pipeline._interrupt:
            return [None] * sz
        for img, pe, vec in zip(img_list, pe_list, vec_list):
            img[...] = block(x=img, vec=vec, pe=pe)
        img = pe = vec = None

    # Keep only the image tokens that follow the text tokens.
    img_list = [
        img[:, txt.shape[1] : txt.shape[1] + img_len, ...]
        for img, txt in zip(img_list, txt_list)
    ]

    if self.radiance:
        final_vecs = None
    elif self.chroma:
        final_vecs = [self.get_modulations(mod_vectors, "final")] * sz
    else:
        final_vecs = base_vec_list

    out_list = []
    for i in range(sz):
        hidden_seq = img_list[i]
        if self.radiance:
            base_image = base_image_list[i]
            pred_tokens = apply_radiance_head(
                module=self,
                hidden_seq=hidden_seq,
                base_image=base_image,
                height=height,
                width=width,
            )
            # Release references as soon as each stream is finished.
            base_image_list[i] = base_image = None
            img_list[i] = hidden_seq = None
        else:
            vec = final_vecs[i]
            if self.piflow:
                modulated, pred_tokens = self._apply_piflow_final_layer(
                    hidden_seq, vec, img_ids, img_len
                )
            else:
                modulated, pred_tokens = self._apply_final_layer(hidden_seq, vec)
            img_list[i] = hidden_seq = vec = modulated = None
        out_list.append(pred_tokens)
    return out_list
|
|
|
|
|
|
|
|
class FluxLoraWrapper(Flux):
    """``Flux`` subclass that hands itself to ``replace_linear_with_lora`` after construction."""

    def __init__(self, lora_rank: int = 128, lora_scale: float = 1.0, *args, **kwargs) -> None:
        """Build the base model, then apply the LoRA replacement.

        Args:
            lora_rank: passed to ``replace_linear_with_lora`` as ``max_rank``
                and kept on the instance as ``self.lora_rank``.
            lora_scale: passed to ``replace_linear_with_lora`` as ``scale``.
            *args: forwarded unchanged to ``Flux.__init__``.
            **kwargs: forwarded unchanged to ``Flux.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.lora_rank = lora_rank
        replace_linear_with_lora(self, max_rank=lora_rank, scale=lora_scale)

    def set_lora_scale(self, scale: float) -> None:
        """Call ``set_scale(scale=scale)`` on every ``LinearLora`` sub-module."""
        lora_layers = (m for m in self.modules() if isinstance(m, LinearLora))
        for layer in lora_layers:
            layer.set_scale(scale=scale)
|
|
|