Issues with LoRA training on a Dual-GPU setup
Because I am running a dual-GPU configuration, standard training tools like ComfyUI or Kohya don't seem to use my hardware efficiently. Consequently, I've written a custom LoRA training script with the Diffusers library.
I am currently training on a dataset of 64 high-quality images, each with a detailed natural-language caption. Despite my efforts, I keep hitting severe issues: either anatomical deformations or output that collapses into pure noise. My goal is to capture a specific animation style, but the model fails to converge.
I've tried targeting different layers (MLP vs. attention), but the results remain unstable. Could you please review my implementation, especially how I'm handling the model parallelism and the layer targeting?
Thank you.
==================================================
import os
import random
import types
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from diffusers import Flux2KleinPipeline
from diffusers.optimization import get_scheduler
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
MODEL_PATH = "/mnt/d/FLUX.2-klein-base-9B"
DATASET_DIR = "/mnt/d/FLUX.2-klein-base-9B"
OUTPUT_DIR = "./flux2_lora_output"
BATCH_SIZE = 2
EPOCHS = 200
LEARNING_RATE = 1e-4
GRADIENT_ACCUMULATION_STEPS = 4
LORA_RANK = 32
LORA_ALPHA = 32
os.makedirs(OUTPUT_DIR, exist_ok=True)
class FluxDataset(Dataset):
    def __init__(self, data_dir, size=512):  # NOTE: `size` is currently unused; the transform below fixes 1088x768
self.data_dir = data_dir
self.size = size
self.image_paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
self.transform = transforms.Compose([
transforms.Resize((1088, 768)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert("RGB")
image = self.transform(image)
txt_path = os.path.splitext(img_path)[0] + ".txt"
prompt = ""
if os.path.exists(txt_path):
with open(txt_path, "r", encoding="utf-8") as f:
prompt = f.read().strip()
return {"image": image, "prompt": prompt}
def main():
print("🚀 啟動 FLUX.2 ...")
print("➤ 正在透過實體 RAM 載入全模型...")
pipe = Flux2KleinPipeline.from_pretrained(
MODEL_PATH,
local_files_only=True,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=False
)
print("➤ 正在修復 Qwen3 聊天模板...")
safe_chat_template = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\\n'}}"
"{% if message['content'] is string %}"
"{{ message['content'] }}"
"{% else %}"
"{% for item in message['content'] %}"
"{% if item['type'] == 'image' %}[IMG]"
"{% elif item['type'] == 'text' %}{{ item['text'] }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"{{'<|im_end|>\\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)
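    # For reference, this template renders a message list into Qwen-style turns,
    # e.g. [{"role": "user", "content": "a cat"}] becomes
    # "<|im_start|>user\na cat<|im_end|>\n", with image items replaced by a
    # literal [IMG] token.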
if hasattr(pipe, "tokenizer_2") and pipe.tokenizer_2 is not None:
pipe.tokenizer_2.chat_template = safe_chat_template
elif hasattr(pipe, "tokenizer") and pipe.tokenizer is not None:
pipe.tokenizer.chat_template = safe_chat_template
print("➤ 將 Qwen3 (Text Encoder) 移至 GPU 1 並凍結...")
text_encoder = pipe.text_encoder.to("cuda:1")
text_encoder.requires_grad_(False)
tokenizer = pipe.tokenizer_2 if hasattr(pipe, "tokenizer_2") else pipe.tokenizer
print("➤ 將 VAE 移至 GPU 0 並凍結...")
vae = pipe.vae.to(device="cuda:0", dtype=torch.float32)
vae.requires_grad_(False)
print("➤ 將 Transformer 移至 GPU 0...")
transformer = pipe.transformer.to("cuda:0")
transformer.requires_grad_(False)
print("➤ 正在架設 Pipeline 跨卡編碼橋樑...")
import types
original_encode_prompt = pipe.encode_prompt
def patched_encode_prompt(self, *args, **kwargs):
kwargs['device'] = torch.device("cuda:1")
result = original_encode_prompt(*args, **kwargs)
if isinstance(result, tuple):
return tuple(r.to("cuda:0") if hasattr(r, "to") else r for r in result)
return result.to("cuda:0") if hasattr(result, "to") else result
pipe.encode_prompt = types.MethodType(patched_encode_prompt, pipe)
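    # Optional smoke test for the bridge (hypothetical; assumes encode_prompt
    # returns a tuple of tensors): prompts are encoded on cuda:1 and the
    # embeddings must arrive on cuda:0, where the transformer lives.
    # with torch.no_grad():
    #     _out = pipe.encode_prompt(prompt=["test"], max_sequence_length=512)
    #     assert all(t.device == torch.device("cuda:0")
    #                for t in _out if torch.is_tensor(t))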
transformer.enable_gradient_checkpointing()
single_out_layers = [f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)]
print("➤ 正在注入 LoRA 權重...")
lora_config = LoraConfig(
r=LORA_RANK,
lora_alpha=LORA_ALPHA,
init_lora_weights=True,
target_modules=[
# Joint Blocks Attention
"to_q", "to_k", "to_v", "to_out.0",
"add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
# Joint Blocks MLPs
"ff.linear_in", "ff.linear_out",
"ff_context.linear_in", "ff_context.linear_out",
# Single Blocks (FLUX fuses Attention and MLPs together here)
"to_qkv_mlp_proj"
] + single_out_layers,
)
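    # Before trusting these target_modules, it helps to dump the transformer's
    # actual Linear-layer names and check them against the list above
    # (a quick diagnostic, commented out here):
    # for n, m in transformer.named_modules():
    #     if isinstance(m, torch.nn.Linear) and ("attn" in n or "ff" in n):
    #         print(n)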
print("\n" + "="*50)
print("📊 LoRA 掛載狀態體檢報告")
print("="*50)
transformer = get_peft_model(transformer, lora_config)
transformer.print_trainable_parameters()
trainable_layer_names = [name for name, _ in transformer.named_parameters() if "lora" in name]
if len(trainable_layer_names) == 0:
print("🚨 沒有任何一層成功掛上 LoRA!請檢查 target_modules 名稱是否全錯!")
else:
print(f"\n✅ 成功掛載了 {len(trainable_layer_names)} 個 LoRA 張量矩陣!")
for name in trainable_layer_names:
print(f" - {name}")
print("="*50 + "\n")
print("➤ 準備優化器與 DataLoader...")
    optimizer = bnb.optim.AdamW8bit(
        [p for p in transformer.parameters() if p.requires_grad],  # only the LoRA weights
        lr=LEARNING_RATE,
        weight_decay=1e-4
    )
dataset = FluxDataset(DATASET_DIR)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    total_steps = (len(dataloader) * EPOCHS) // GRADIENT_ACCUMULATION_STEPS  # the scheduler steps once per optimizer step, not per batch
lr_scheduler = get_scheduler(
"cosine",
optimizer=optimizer,
num_warmup_steps=min(100, total_steps // 10),
num_training_steps=total_steps
)
print("🔥 開始訓練!")
global_step = 0
transformer.train()
optimizer.zero_grad()
for epoch in range(EPOCHS):
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
for batch in progress_bar:
images = batch["image"].to("cuda:0", dtype=torch.float32)
with torch.no_grad():
latents = vae.encode(images).latent_dist.sample().to(torch.bfloat16)
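            # FLUX.1-style latent normalization constants; if the FLUX.2 VAE
            # exposes them, prefer vae.config.shift_factor and
            # vae.config.scaling_factor over these hard-coded values
            # (assuming the same convention).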
shift = 0.1159
scale = 0.3611
latents = (latents - shift) * scale
b, c, h, w = latents.shape
latents_packed = latents.view(b, c, h // 2, 2, w // 2, 2)
latents_packed = latents_packed.permute(0, 2, 4, 1, 3, 5)
latents_packed = latents_packed.reshape(b, (h // 2) * (w // 2), c * 4)
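            # The three lines above patchify the latents: (b, c, h, w) ->
            # (b, (h/2)*(w/2), c*4), i.e. one sequence token per 2x2 latent patch.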
img_ids = torch.zeros((h // 2, w // 2, 4), device="cuda:0", dtype=torch.bfloat16)
img_ids[..., 2] = img_ids[..., 2] + torch.arange(h // 2, device="cuda:0")[:, None]
img_ids[..., 3] = img_ids[..., 3] + torch.arange(w // 2, device="cuda:0")[None, :]
img_ids = img_ids.reshape(-1, 4).unsqueeze(0).repeat(b, 1, 1)
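            # img_ids carries per-token 2D positions (row in channel 2, column
            # in channel 3) for the transformer's positional embedding, assuming
            # the same RoPE-style id convention as FLUX.1; channels 0-1 stay zero.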
prompts = batch["prompt"]
            # [CRITICAL FIX] 10% prompt dropout for CFG support
            if random.random() < 0.1:
                prompts = [""] * len(prompts)
with torch.no_grad():
encode_out = pipe.encode_prompt(prompt=prompts, max_sequence_length=512)
encoder_hidden_states = encode_out[0]
real_txt_ids = encode_out[-1]
encoder_hidden_states = encoder_hidden_states.to("cuda:0", dtype=torch.bfloat16)
txt_ids = real_txt_ids.to("cuda:0", dtype=torch.bfloat16)
            # Flow matching: build noisy latents and the velocity target
z_1 = latents_packed
z_0 = torch.randn_like(z_1)
bsz = z_1.shape[0]
target = z_0 - z_1
u = torch.randn((bsz,), device="cuda:0", dtype=torch.float32)
t = torch.sigmoid(u).to(torch.bfloat16)
t_expand = t.view(-1, 1, 1)
z_t = t_expand * z_0 + (1.0 - t_expand) * z_1
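            # Rectified-flow check: with z_t = t*z_0 + (1-t)*z_1, the velocity
            # dz_t/dt = z_0 - z_1 is constant, which is exactly `target` above.
            # t = sigmoid(N(0, 1)) is logit-normal timestep sampling, which
            # concentrates training at mid-range noise levels.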
timestep_input = t.to(torch.float32)
with torch.autocast("cuda", dtype=torch.bfloat16):
model_pred = transformer(
hidden_states=z_t,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
                    timestep=timestep_input,  # continuous flow-matching time in [0.0, 1.0]
return_dict=False
)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss = loss / GRADIENT_ACCUMULATION_STEPS
loss.backward()
global_step += 1
if global_step % GRADIENT_ACCUMULATION_STEPS == 0:
torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.set_postfix({"loss": f"{(loss.item() * GRADIENT_ACCUMULATION_STEPS):.4f}"})
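        # Effective batch size is BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 8;
        # with 64 images that is 32 batches, i.e. 8 optimizer steps, per epoch.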
current_epoch = epoch + 1
if current_epoch % 10 == 0:
checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-epoch-{current_epoch}")
os.makedirs(checkpoint_dir, exist_ok=True)
lora_state_dict = get_peft_model_state_dict(transformer)
            clean_state_dict = {k.replace("base_model.model.", ""): v for k, v in lora_state_dict.items()}
Flux2KleinPipeline.save_lora_weights(
save_directory=checkpoint_dir,
transformer_lora_layers=clean_state_dict,
weight_name="flux2_lora.safetensors"
)
print(f"\n💾 儲存第 {current_epoch} 個 Epoch 至 {checkpoint_dir}...")
print(f"🎉 訓練完成!正在儲存最終版本至 {OUTPUT_DIR}...")
final_lora_state_dict = get_peft_model_state_dict(transformer)
clean_final_state_dict = {k.replace("base_model.model.", ""): v for k, v in final_lora_state_dict.items()}
Flux2KleinPipeline.save_lora_weights(
save_directory=OUTPUT_DIR,
transformer_lora_layers=clean_final_state_dict,
weight_name="flux2_lora_final.safetensors"
)
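    # Loading the result for inference (a sketch; assumes Flux2KleinPipeline
    # exposes the standard diffusers LoRA-loader mixin):
    # pipe.load_lora_weights(OUTPUT_DIR, weight_name="flux2_lora_final.safetensors")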
if __name__ == "__main__":
main()