|
|
""" |
|
|
LECO Attribute Binding Trainer - COMPLETE WITH PROPER FLOW MATCHING |
|
|
Complete script with correct flow matching SNR and velocity prediction |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import datetime |
|
|
import random |
|
|
from dataclasses import dataclass, asdict, field |
|
|
from typing import List, Tuple |
|
|
from tqdm.auto import tqdm |
|
|
from itertools import product |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
from diffusers import UNet2DConditionModel |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class AttributePair:
    """One correct two-attribute combination plus its wrong variants.

    The dataclass is frozen, so instances are immutable and hashable.
    """

    # First attribute phrase of the correct combination.
    attr1: str
    # Second attribute phrase of the correct combination.
    attr2: str
    # Prompts describing combinations that should NOT be produced; empty
    # when no negatives were supplied.
    negatives: Tuple[str, ...] = ()
    # Multiplier applied to this pair's positive loss term.
    weight: float = 1.0
|
|
|
|
|
|
|
|
@dataclass
class AttributeBindingConfig:
    """Settings consumed by the attribute-binding training run."""

    # --- Output / base model -------------------------------------------
    # Parent directory; each run writes into a timestamped subfolder of it.
    output_dir: str = "./leco_outputs"
    # Hugging Face repo id and file name of the checkpoint to download.
    base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
    base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"
    # Prefix of the saved ``.safetensors`` file names.
    name_prefix: str = "leco"

    # --- Training data ---------------------------------------------------
    # Pairs that are sampled from at every training step.
    attribute_pairs: List[AttributePair] = field(default_factory=list)

    # --- LoRA shape (forwarded to LECOConfig) ------------------------------
    lora_rank: int = 8
    lora_alpha: float = 1.0
    training_method: str = "xattn"

    # --- Optimisation ------------------------------------------------------
    # Passed to torch.manual_seed at the start of training.
    seed: int = 42
    # Number of optimisation steps.
    iterations: int = 500
    # A checkpoint is written every this many steps (and at the last step).
    save_every: int = 250
    # AdamW learning rate.
    lr: float = 2e-4
    # Upper bound on the number of pairs sampled per step.
    pairs_per_batch: int = 4
    # Upper bound on the number of negatives sampled for each pair.
    negatives_per_positive: int = 2

    # --- Loss weighting ----------------------------------------------------
    # When True the loss is scaled by a clamped-SNR weight using this gamma.
    use_min_snr: bool = True
    min_snr_gamma: float = 5.0

    # --- Noise level sampling ------------------------------------------------
    # Shift applied as sigma' = shift * sigma / (1 + (shift - 1) * sigma).
    shift: float = 2.5
    # Timestep bounds; divided by 1000 to obtain the sigma range.
    min_timestep: float = 0.0
    max_timestep: float = 1000.0
    # Image resolution; the sampled latent is resolution // 8 on each side.
    resolution: int = 512
|
|
|
|
|
|
|
|
@dataclass
class LECOConfig:
    """The subset of settings needed to build the LoRA layers."""

    # Inner dimension of each LoRA down/up weight pair.
    lora_rank: int = 4
    # Stored per layer as ``alpha``; hooks scale the LoRA output by alpha / rank.
    lora_alpha: float = 1.0
    # Selects which attention projections receive LoRA weights.
    training_method: str = "xattn"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_color(text: str) -> "str | None":
    """Return the first known color named in ``text``, or ``None``.

    Matching is case-insensitive and the returned color is always lowercase.
    Whole-word matches are preferred so that a color hidden inside another
    word ("red" in "scarred") cannot shadow a real color word elsewhere in
    the phrase. If no whole word matches, the original substring search is
    used as a fallback, so inputs such as "golden hair" still yield "gold".

    Candidates are tried in the order of the internal color list, not in
    the order they appear in ``text``.

    Args:
        text: Attribute phrase to inspect.

    Returns:
        The matched color name, or ``None`` when no color is found.
    """
    import re  # local import: the module's import block does not provide it

    colors = [
        "red", "blue", "green", "yellow", "purple", "orange", "pink",
        "black", "white", "brown", "blonde", "silver", "gold", "cyan",
        "magenta", "teal", "lavender", "gray", "grey", "beige", "navy",
        "maroon", "turquoise", "violet", "indigo", "crimson"
    ]
    text_lower = text.lower()

    # Pass 1: colors that appear as standalone words.
    for color in colors:
        if re.search(rf"\b{color}\b", text_lower):
            return color

    # Pass 2: legacy behaviour - any substring occurrence.
    for color in colors:
        if color in text_lower:
            return color

    return None
|
|
|
|
|
|
|
|
def generate_smart_negatives(attr1: str, attr2: str, all_negatives: List[str] = None) -> List[str]:
    """Build prompts describing wrong combinations of ``attr1`` and ``attr2``.

    When both attributes contain a (different) color, three color-confused
    prompts are produced: both colors swapped, and each attribute taking the
    other's color. Every entry of ``all_negatives`` is additionally appended
    to the correct prompt to form one more negative each.

    The result is de-duplicated while keeping a stable, deterministic order,
    and never contains the correct prompt ``"{attr1}, {attr2}"`` itself.

    Args:
        attr1: First attribute phrase.
        attr2: Second attribute phrase.
        all_negatives: Optional extra negative suffixes; ``None`` or empty
            adds nothing.

    Returns:
        List of unique negative prompts (possibly empty).
    """
    import re  # local import: the module's import block does not provide it

    correct = f"{attr1}, {attr2}"
    candidates = []

    color1 = extract_color(attr1)
    color2 = extract_color(attr2)

    if color1 and color2 and color1 != color2:
        def swap(text, old, new):
            # extract_color matches case-insensitively, so the replacement
            # must too; otherwise "Red hair" would be left untouched.
            return re.sub(re.escape(old), new, text, flags=re.IGNORECASE)

        swapped_attr1 = swap(attr1, color1, color2)
        swapped_attr2 = swap(attr2, color2, color1)
        candidates.append(f"{swapped_attr1}, {swapped_attr2}")
        candidates.append(f"{attr1}, {swapped_attr2}")
        candidates.append(f"{swapped_attr1}, {attr2}")

    if all_negatives:
        for neg in all_negatives:
            candidates.append(f"{attr1}, {attr2}, {neg}")

    # dict.fromkeys de-duplicates while preserving insertion order; a set
    # would make the order (and thus seeded sampling) vary between runs.
    return [neg for neg in dict.fromkeys(candidates) if neg != correct]
|
|
|
|
|
|
|
|
def create_attribute_combinations(
    pair_attr1: List[str],
    pair_attr2: List[str],
    negatives: List[str] = None,
    weight: float = 1.0,
    auto_generate_negatives: bool = True
) -> List[AttributePair]:
    """Build an ``AttributePair`` for every (attr1, attr2) combination.

    Args:
        pair_attr1: Candidate values for the first attribute.
        pair_attr2: Candidate values for the second attribute.
        negatives: Optional extra negative phrases.
        weight: Weight stored on every generated pair.
        auto_generate_negatives: When True, negatives come from
            ``generate_smart_negatives``; otherwise each entry of
            ``negatives`` is combined once with ``attr1`` and once with
            ``attr2``.

    Returns:
        One pair per element of the cartesian product, in product order.
    """

    def negatives_for(first: str, second: str) -> list:
        # Delegates to the smart generator, or falls back to the manual mix.
        if auto_generate_negatives:
            return generate_smart_negatives(first, second, negatives)
        manual = []
        for extra in negatives or ():
            manual.append(f"{first}, {extra}")
            manual.append(f"{extra}, {second}")
        return manual

    return [
        AttributePair(
            attr1=first,
            attr2=second,
            negatives=tuple(negatives_for(first, second)),
            weight=weight,
        )
        for first, second in product(pair_attr1, pair_attr2)
    ]
|
|
|
|
|
|
|
|
def combine_attribute_groups(*groups: List[AttributePair]) -> List[AttributePair]:
    """Flatten several groups of pairs into one new list, keeping order."""
    return [pair for group in groups for pair in group]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_target_modules(training_method: str) -> List[str]:
    """Return the attention-projection name fragments that receive LoRA.

    ``"selfattn"`` and ``"noxattn"`` select the ``attn1`` projections,
    ``"xattn"`` and ``"innoxattn"`` select the ``attn2`` projections, and
    ``"full"`` or any unrecognised value selects both.
    """
    projections = ("to_q", "to_k", "to_v", "to_out.0")
    self_attn = [f"attn1.{proj}" for proj in projections]
    cross_attn = [f"attn2.{proj}" for proj in projections]
    everything = self_attn + cross_attn

    by_method = {
        "full": everything,
        "selfattn": self_attn,
        "xattn": cross_attn,
        "noxattn": self_attn,
        "innoxattn": cross_attn,
    }
    try:
        return by_method[training_method]
    except KeyError:
        # Unknown methods fall back to training every projection.
        return everything
|
|
|
|
|
|
|
|
def create_lora_layers(unet: nn.Module, config: LECOConfig):
    """Allocate LoRA down/up weights for every targeted ``nn.Linear``.

    A module is targeted when its dotted name contains one of the fragments
    returned by ``get_target_modules`` and it is an ``nn.Linear``. For each
    one, four entries are stored under the prefix
    ``lora_unet_<name with dots replaced by underscores>``:
    ``.lora_down.weight``, ``.lora_up.weight``, ``.alpha`` and ``._module``
    (a reference to the wrapped layer itself).

    The down weight is Kaiming-uniform initialised and the up weight is all
    zeros.

    Returns:
        ``(lora_state, trainable_params)`` where ``trainable_params`` lists
        the down/up parameters in creation order.
    """
    wanted_fragments = get_target_modules(config.training_method)
    lora_state = {}
    trainable_params = []

    print(f"Creating LoRA layers (method: {config.training_method})...")

    for module_path, layer in unet.named_modules():
        name_matches = any(fragment in module_path for fragment in wanted_fragments)
        if not (name_matches and isinstance(layer, nn.Linear)):
            continue

        prefix = "lora_unet_" + module_path.replace(".", "_")
        rank = config.lora_rank

        down = nn.Parameter(torch.zeros(rank, layer.in_features))
        up = nn.Parameter(torch.zeros(layer.out_features, rank))
        nn.init.kaiming_uniform_(down, a=1.0)
        nn.init.zeros_(up)

        lora_state[prefix + ".lora_down.weight"] = down
        lora_state[prefix + ".lora_up.weight"] = up
        lora_state[prefix + ".alpha"] = torch.tensor(config.lora_alpha)
        lora_state[prefix + "._module"] = layer

        trainable_params += [down, up]

    print(f"✓ Created {len(trainable_params)//2} LoRA layers ({len(trainable_params)} parameters)")
    return lora_state, trainable_params
|
|
|
|
|
|
|
|
def apply_lora_hooks(unet: nn.Module, lora_state: dict, scale: float = 1.0) -> list:
    """Attach the LoRA weights in ``lora_state`` as forward hooks.

    For every ``*.lora_down.weight`` entry, a hook is registered on the
    module stored under the matching ``*._module`` key. The hook adds
    ``up(down(x)) * (alpha / rank) * scale`` to the module's output, where
    ``rank`` is the first dimension of the down weight.

    Returns:
        The hook handles, in registration order, for later removal.
    """
    down_suffix = ".lora_down.weight"

    def build_hook(down, up, factor):
        # Bind the weights per layer so each closure keeps its own tensors.
        def add_low_rank(mod, inp, out):
            projected = F.linear(inp[0], down)
            return out + F.linear(projected, up) * factor
        return add_low_rank

    handles = []
    for state_key in lora_state:
        if not state_key.endswith(down_suffix):
            continue

        base_key = state_key.replace(down_suffix, "")
        target = lora_state[f"{base_key}._module"]
        down_weight = lora_state[f"{base_key}.lora_down.weight"]
        up_weight = lora_state[f"{base_key}.lora_up.weight"]
        alpha = lora_state[f"{base_key}.alpha"].item()

        factor = (alpha / down_weight.shape[0]) * scale
        handles.append(
            target.register_forward_hook(build_hook(down_weight, up_weight, factor))
        )

    return handles
|
|
|
|
|
|
|
|
def remove_lora_hooks(handles: list):
    """Call ``remove()`` on every hook handle, in list order."""
    for hook_handle in handles:
        hook_handle.remove()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_attribute_binding_loss_batched(
    unet,
    lora_state,
    positive_pairs: List[AttributePair],
    tokenizer,
    text_encoder,
    config: AttributeBindingConfig,
    device: str = "cuda"
):
    """Compute the attribute-binding loss for one batch of pairs.

    A single sigma and a single noise latent are sampled and shared by every
    prompt in the batch. The UNet is evaluated without LoRA (under
    ``torch.no_grad``) on the empty prompt and on all positive and negative
    prompts, and then once more with the LoRA hooks attached. The loss
    compares the change introduced by the LoRA against scaled targets:

    * positives: ``0.3 * (baseline - neutral)``
    * negatives: ``0.2 * (neutral - baseline)``

    The total is ``positive + 0.5 * negative``, each term multiplied by the
    SNR weight when ``config.use_min_snr`` is set.

    Args:
        unet: Called as ``unet(latents, timestep, encoder_hidden_states=...,
            return_dict=False)``; element ``[0]`` of the result is used.
        lora_state: LoRA state dict passed to ``apply_lora_hooks``.
        positive_pairs: Pairs for this batch; up to
            ``config.negatives_per_positive`` negatives are sampled from each.
        tokenizer: Called on the list of prompts; the result must support
            ``.to(device)`` and expose ``input_ids``.
        text_encoder: Called on ``input_ids``; element ``[0]`` of the result
            is used as the prompt embeddings.
        config: Supplies timestep range, shift, resolution and SNR settings.
        device: Device on which tensors are created.

    Returns:
        ``(total_loss, metrics)``. ``metrics`` is a dict of Python numbers
        for logging. When ``positive_pairs`` is empty the loss is a constant
        zero tensor (no gradient) and all metrics are zero/default.
    """
    # --- Sample the noise level ------------------------------------------
    min_sigma = config.min_timestep / 1000.0
    max_sigma = config.max_timestep / 1000.0

    sigma = torch.rand(1, device=device)
    sigma = min_sigma + sigma * (max_sigma - min_sigma)

    # Timestep shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
    sigma = (config.shift * sigma) / (1 + (config.shift - 1) * sigma)
    timestep = sigma * 1000.0
    sigma_expanded = sigma.view(1, 1, 1, 1)

    # The input is pure noise scaled by sigma; no image latent is mixed in.
    noise = torch.randn(1, 4, config.resolution // 8, config.resolution // 8, device=device)
    noisy_input = sigma_expanded * noise

    # --- Collect prompts ---------------------------------------------------
    positive_prompts = []
    negative_prompts = []
    pair_weights = []

    for pair in positive_pairs:
        correct = f"{pair.attr1}, {pair.attr2}"
        positive_prompts.append(correct)
        pair_weights.append(pair.weight)

        if pair.negatives:
            sampled_negs = random.sample(
                list(pair.negatives),
                min(config.negatives_per_positive, len(pair.negatives))
            )
            negative_prompts.extend(sampled_negs)

    if not positive_prompts:
        # Same keys as the normal path so callers can log unconditionally.
        return torch.tensor(0.0, device=device), {
            "positive_loss": 0, "negative_loss": 0,
            "positive_count": 0, "negative_count": 0,
            "timestep": 0.0, "sigma": 0.0, "snr_weight": 1.0,
            "lora_positive_norm": 0.0, "lora_negative_norm": 0.0
        }

    # --- Encode all prompts in one pass (index 0 is the empty prompt) -------
    neutral_prompt = ""
    all_prompts = [neutral_prompt] + positive_prompts + negative_prompts

    text_inputs = tokenizer(
        all_prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)

    all_embeddings = text_encoder(text_inputs.input_ids)[0]

    neutral_emb = all_embeddings[0:1]
    positive_embs = all_embeddings[1:1+len(positive_prompts)]
    negative_embs = all_embeddings[1+len(positive_prompts):]

    # Every non-neutral prompt sees the same latent and timestep.
    batch_size = len(all_prompts) - 1
    noisy_input_batch = noisy_input.repeat(batch_size, 1, 1, 1)
    timestep_batch = timestep.repeat(batch_size)

    combined_embs = torch.cat([positive_embs, negative_embs], dim=0)

    # --- Reference predictions without LoRA ---------------------------------
    with torch.no_grad():
        vel_neutral = unet(
            noisy_input, timestep_batch[0:1],
            encoder_hidden_states=neutral_emb,
            return_dict=False
        )[0]

        vel_baseline = unet(
            noisy_input_batch, timestep_batch,
            encoder_hidden_states=combined_embs,
            return_dict=False
        )[0]

    vel_positive_baseline = vel_baseline[:len(positive_prompts)]
    vel_negative_baseline = vel_baseline[len(positive_prompts):]

    # --- Prediction with LoRA attached (hooks always removed afterwards) ----
    handles = apply_lora_hooks(unet, lora_state, scale=1.0)

    try:
        vel_with_lora = unet(
            noisy_input_batch, timestep_batch,
            encoder_hidden_states=combined_embs,
            return_dict=False
        )[0]
    finally:
        remove_lora_hooks(handles)

    vel_positive_lora = vel_with_lora[:len(positive_prompts)]
    vel_negative_lora = vel_with_lora[len(positive_prompts):]

    # --- SNR weighting -------------------------------------------------------
    snr_weight = 1.0
    if config.use_min_snr:
        sigma_sq = sigma.squeeze() ** 2
        snr = ((1 - sigma.squeeze()) ** 2) / (sigma_sq + 1e-8)

        # min(snr, gamma) / snr is the same as min(1, gamma / snr). Written
        # this way (with snr floored) the weight stays finite when sigma is
        # exactly 1 and snr is 0; the plain ratio would be 0/0 = NaN and a
        # single NaN loss would corrupt the LoRA weights.
        snr_weight_tensor = torch.clamp(
            config.min_snr_gamma / snr.clamp_min(1e-8), max=1.0
        )

        # Additional 1 / (snr + 1) factor applied on top of the clamp.
        snr_weight_tensor = snr_weight_tensor / (snr + 1)

        snr_weight = snr_weight_tensor.item()
    else:
        snr_weight_tensor = torch.ones(1, device=device)

    # --- Positive term -------------------------------------------------------
    vel_neutral_expanded = vel_neutral.expand_as(vel_positive_baseline)
    target_positive_direction = vel_positive_baseline - vel_neutral_expanded
    lora_positive_delta = vel_positive_lora - vel_positive_baseline

    positive_loss_per_sample = F.mse_loss(
        lora_positive_delta,
        target_positive_direction * 0.3,
        reduction='none'
    ).mean(dim=(1,2,3))

    pair_weights_tensor = torch.tensor(pair_weights, device=device)
    weighted_positive_loss = (positive_loss_per_sample * pair_weights_tensor * snr_weight_tensor).mean()

    # --- Negative term (only when negatives were sampled) --------------------
    negative_loss = torch.tensor(0.0, device=device)
    lora_negative_norm = 0.0

    if len(negative_prompts) > 0:
        vel_neutral_expanded_neg = vel_neutral.expand_as(vel_negative_baseline)
        target_negative_direction = vel_neutral_expanded_neg - vel_negative_baseline
        lora_negative_delta = vel_negative_lora - vel_negative_baseline

        negative_loss = F.mse_loss(lora_negative_delta, target_negative_direction * 0.2, reduction='mean')
        negative_loss = negative_loss * snr_weight_tensor
        lora_negative_norm = lora_negative_delta.norm().item()

    total_loss = weighted_positive_loss + negative_loss * 0.5

    metrics = {
        "positive_loss": weighted_positive_loss.item(),
        "negative_loss": negative_loss.item() if isinstance(negative_loss, torch.Tensor) else 0.0,
        "positive_count": len(positive_prompts),
        "negative_count": len(negative_prompts),
        "timestep": timestep.item(),
        "sigma": sigma.item(),
        "snr_weight": snr_weight,
        "lora_positive_norm": lora_positive_delta.norm().item(),
        "lora_negative_norm": lora_negative_norm
    }

    return total_loss, metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_attribute_binding(config: AttributeBindingConfig):
    """Train an attribute-binding LoRA and write checkpoints to disk.

    Steps: create a timestamped output directory and TensorBoard writer,
    dump the config as JSON, load the stock SD1.5 UNet once for a sanity
    comparison, load the flow-matching checkpoint into a fresh UNet, load
    CLIP, create the LoRA weights, then run ``config.iterations`` AdamW
    steps, saving ``.safetensors`` files every ``config.save_every`` steps
    and at the final step.

    Args:
        config: Training settings; ``attribute_pairs`` must be non-empty.

    Returns:
        Path of the run's output directory.

    Raises:
        ValueError: If ``config.attribute_pairs`` is empty.
    """
    device = "cuda"
    torch.manual_seed(config.seed)
    # Pair and negative sampling use Python's ``random`` module, so it has to
    # be seeded as well for a run to be reproducible from ``config.seed``.
    random.seed(config.seed)

    if not config.attribute_pairs:
        raise ValueError("No attribute pairs specified!")

    pairs_with_negatives = sum(1 for p in config.attribute_pairs if p.negatives)
    print(f"Pairs with explicit negatives: {pairs_with_negatives}/{len(config.attribute_pairs)}")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(config.output_dir, f"attribute_binding_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=output_dir, flush_secs=60)

    # Everything below can fail (downloads, CUDA OOM, ...); make sure the
    # TensorBoard writer is flushed and closed in every case.
    try:
        with open(os.path.join(output_dir, "config.json"), "w") as f:
            json.dump(asdict(config), f, indent=2)

        print("="*80)
        print("ATTRIBUTE BINDING TRAINING")
        if config.use_min_snr:
            print(f"Using Min-SNR Weighting (gamma={config.min_snr_gamma})")
        print("="*80)

        # --- Sanity reference: stock SD1.5 UNet output on a fixed input ------
        print("\nVerifying UNet loading...")
        print("Loading base SD1.5 UNet for comparison...")
        unet_base = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="unet",
            torch_dtype=torch.float32
        ).to(device)

        test_latents = torch.randn(1, 4, 64, 64, device=device)
        test_timestep = torch.tensor([500], device=device)
        test_encoder = torch.randn(1, 77, 768, device=device)

        with torch.no_grad():
            baseline_out = unet_base(test_latents, test_timestep, encoder_hidden_states=test_encoder, return_dict=False)[0]

        print(f"Baseline output norm: {baseline_out.norm().item():.6f}")
        del unet_base
        torch.cuda.empty_cache()

        # --- Load the flow-matching checkpoint into a fresh UNet -------------
        print("\nLoading Lune flow-matching model...")
        checkpoint_path = hf_hub_download(
            repo_id=config.base_model_repo,
            filename=config.base_checkpoint,
            repo_type="model"
        )
        # NOTE(security): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted repo (or switch to weights_only=True if
        # the checkpoint format allows it).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        unet = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="unet",
            torch_dtype=torch.float32
        )

        student_dict = checkpoint["student"]
        # Strip a leading "unet." from keys so they match the UNet's own names.
        cleaned_dict = {k[5:] if k.startswith("unet.") else k: v for k, v in student_dict.items()}
        missing, unexpected = unet.load_state_dict(cleaned_dict, strict=False)

        print(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

        unet = unet.to(device)
        unet.requires_grad_(False)
        unet.eval()

        # Same fixed input as above: identical output would mean the
        # checkpoint weights did not replace the stock ones.
        with torch.no_grad():
            lune_out = unet(test_latents, test_timestep, encoder_hidden_states=test_encoder, return_dict=False)[0]

        print(f"Lune output norm: {lune_out.norm().item():.6f}")
        diff = (lune_out - baseline_out).abs().mean().item()
        print(f"Difference from baseline: {diff:.6f}")

        if diff < 1e-4:
            print("⚠️ WARNING: Outputs are nearly identical - checkpoint may not have loaded!")
        else:
            print("✓ Lune checkpoint loaded correctly (outputs differ)")

        # --- Text encoder ---------------------------------------------------
        print("\nLoading CLIP...")
        tokenizer = CLIPTokenizer.from_pretrained(
            "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5", subfolder="text_encoder",
            torch_dtype=torch.float32
        ).to(device)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        print("✓ Loaded CLIP")

        # --- LoRA weights and optimizer ---------------------------------------
        print(f"\nCreating LoRA (rank={config.lora_rank})...")

        leco_config = LECOConfig(
            lora_rank=config.lora_rank,
            lora_alpha=config.lora_alpha,
            training_method=config.training_method
        )

        lora_state, trainable_params = create_lora_layers(unet, leco_config)

        print(f"Moving LoRA parameters to {device}...")
        # Move in place via .data so the Parameter objects held by both
        # lora_state and trainable_params stay the same objects.
        for param in trainable_params:
            param.data = param.data.to(device)

        for key, value in lora_state.items():
            if isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter):
                lora_state[key] = value.to(device)

        optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)

        print(f"\nTraining Configuration:")
        print(f" Attribute pairs: {len(config.attribute_pairs)}")
        for i, pair in enumerate(config.attribute_pairs[:3], 1):
            print(f" {i}. {pair.attr1} + {pair.attr2} (weight: {pair.weight})")
            if pair.negatives:
                print(f" Negatives: {len(pair.negatives)} total")
        if len(config.attribute_pairs) > 3:
            print(f" ... and {len(config.attribute_pairs)-3} more")

        print(f"\n Iterations: {config.iterations}")
        print(f" Pairs per batch: {config.pairs_per_batch}")
        print(f" Negatives per positive: {config.negatives_per_positive}")
        print(f" Learning rate: {config.lr}")
        print("="*80 + "\n")

        # --- Training loop ------------------------------------------------------
        progress = tqdm(range(config.iterations), desc="Training")

        for step in progress:
            sampled_pairs = random.sample(
                config.attribute_pairs,
                min(config.pairs_per_batch, len(config.attribute_pairs))
            )

            loss, metrics = compute_attribute_binding_loss_batched(
                unet, lora_state,
                sampled_pairs,
                tokenizer, text_encoder,
                config,
                device
            )

            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            writer.add_scalar("loss/total", loss.item(), step)
            writer.add_scalar("loss/positive", metrics["positive_loss"], step)
            writer.add_scalar("loss/negative", metrics["negative_loss"], step)
            writer.add_scalar("grad_norm", grad_norm.item(), step)
            writer.add_scalar("snr_weight", metrics["snr_weight"], step)

            progress.set_postfix({
                "loss": f"{loss.item():.4f}",
                "pos": f"{metrics['positive_loss']:.3f}",
                "neg": f"{metrics['negative_loss']:.3f}",
                "snr": f"{metrics['snr_weight']:.2f}",
                "grad": f"{grad_norm.item():.3f}"
            })

            # save_every <= 0 disables periodic saves instead of raising
            # ZeroDivisionError; the final step is always saved.
            periodic_save = config.save_every > 0 and (step + 1) % config.save_every == 0
            if periodic_save or step == config.iterations - 1:
                save_dict = {}
                # Only tensors are serialised; the "._module" entries hold
                # nn.Module references and are skipped.
                for key, value in lora_state.items():
                    if isinstance(value, torch.Tensor) and not key.endswith("._module"):
                        save_dict[key] = value.detach().cpu()

                metadata = {
                    "ss_network_module": "networks.lora",
                    "ss_network_dim": str(config.lora_rank),
                    "ss_network_alpha": str(config.lora_alpha),
                    "ss_training_method": config.training_method,
                    "leco_action": "attribute_binding",
                    "leco_num_pairs": str(len(config.attribute_pairs)),
                    "leco_step": str(step + 1),
                    "leco_min_snr": str(config.use_min_snr),
                    "leco_min_snr_gamma": str(config.min_snr_gamma)
                }

                filename = f"{config.name_prefix}_r{config.lora_rank}_s{step+1}.safetensors"
                filepath = os.path.join(output_dir, filename)

                save_file(save_dict, filepath, metadata=metadata)
                print(f"\n✓ Saved: {filename}")
    finally:
        writer.close()

    print("\n" + "="*80)
    print("✅ Training complete!")
    print(f"Output: {output_dir}")
    print("="*80)

    return output_dir
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example run: bind hair colors to shirt colors.

    # One combined quality prompt, appended to each correct prompt to form
    # an extra negative.
    universal_negs = ["ugly, duplicate, morbid, mutilated, blurry, fuzzy, out of frame, gross"]

    palette = ("red", "blue", "green")
    hair_colors = [f"{shade} hair" for shade in palette]
    clothes = [f"{shade} shirt" for shade in palette]

    hair_clothes_pairs = create_attribute_combinations(
        pair_attr1=hair_colors,
        pair_attr2=clothes,
        negatives=universal_negs,
        weight=1.0,
        auto_generate_negatives=True,
    )
    print(f"Generated {len(hair_clothes_pairs)} hair+clothes pairs")

    run_settings = {
        "name_prefix": "color_clothes_test",
        "attribute_pairs": hair_clothes_pairs,
        # Optimisation
        "iterations": 5000,
        "lora_rank": 16,
        "lr": 2e-4,
        "pairs_per_batch": 4,
        "negatives_per_positive": 3,
        "training_method": "xattn",
        "save_every": 250,
        # Noise level sampling
        "shift": 2.5,
        "min_timestep": 0.0,
        "max_timestep": 1000.0,
        # Loss weighting
        "use_min_snr": True,
        "min_snr_gamma": 5.0,
    }
    config = AttributeBindingConfig(**run_settings)

    train_attribute_binding(config)