|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from pathlib import Path |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from safetensors.torch import save_file, load_file |
|
|
from collections import deque |
|
|
from model import LocalSongModel |
|
|
|
|
|
# Tag IDs used as conditioning for every training sample in this script.
# NOTE(review): 1908 is a class index into the model's tag vocabulary
# (num_classes=2304 in main()); its meaning is not visible here — confirm
# against the tag list used when the base model was trained.
HARDCODED_TAGS = [1908]

# Allow TF32 matmuls on Ampere+ GPUs: faster float32 training at slightly
# reduced mantissa precision.
torch.set_float32_matmul_precision('high')
|
|
|
|
|
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank residual.

    Forward computes ``original(x) + (x @ A @ B) * (alpha / rank)``.
    ``B`` is zero-initialized, so the wrapper is an exact no-op at start
    and training departs smoothly from the base model.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        in_features = original_linear.in_features
        out_features = original_linear.out_features
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        # A gets the usual Kaiming init; B stays zero so the initial LoRA
        # contribution is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the wrapped layer: only the LoRA factors receive gradients.
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False

    def forward(self, x):
        """Frozen linear output plus the scaled low-rank correction."""
        base = self.original_linear(x)
        delta = x @ self.lora_A @ self.lora_B
        return base + delta * self.scaling
|
|
|
|
|
def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0, target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'), device=None):
    """Replace targeted ``nn.Linear`` layers in-place with ``LoRALinear`` wrappers.

    Args:
        model: Model to modify (mutated in-place and also returned).
        rank: LoRA rank used for every injected layer.
        alpha: LoRA scaling numerator (effective scale is ``alpha / rank``).
        target_modules: Substrings matched against each module's dotted name;
            any match makes the linear a LoRA target. Tuple (immutable) to
            avoid the shared-mutable-default pitfall.
        device: Device for the new LoRA factors; defaults to the model's device.

    Returns:
        The same ``model`` instance, with matching linears wrapped.
    """
    lora_modules = []

    if device is None:
        device = next(model.parameters()).device

    # Snapshot candidates BEFORE mutating: replacing modules while iterating
    # the live named_modules() generator can re-visit the freshly wrapped
    # linear under the name '<name>.original_linear' (which still contains
    # the target substring) and wrap it a second time.
    candidates = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules)
    ]

    for name, module in candidates:
        # Walk the dotted path to the parent so we can swap the attribute.
        *parent_path, attr_name = name.split('.')
        parent = model
        for p in parent_path:
            parent = getattr(parent, p)

        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)

        # Only the new LoRA factors need moving; the wrapped linear already
        # lives on the model's device.
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)
        lora_modules.append(name)

    print(f"Injected LoRA into {len(lora_modules)} layers:")
    for name in lora_modules[:5]:
        print(f"  - {name}")
    if len(lora_modules) > 5:
        print(f"  ... and {len(lora_modules) - 5} more")

    return model
|
|
|
|
|
def get_lora_parameters(model):
    """Extract only LoRA parameters for optimization.

    Returns a flat list of every ``lora_A`` / ``lora_B`` factor (A before B,
    in module-traversal order) from all ``LoRALinear`` layers in ``model``.
    """
    return [
        factor
        for module in model.modules()
        if isinstance(module, LoRALinear)
        for factor in (module.lora_A, module.lora_B)
    ]
|
|
|
|
|
def save_lora_weights(model, output_path):
    """Save all LoRA factors to a safetensors file.

    Keys are ``<module_name>.lora_A`` / ``<module_name>.lora_B`` so they can
    be matched back to the injected layers on load.
    """
    lora_state_dict = {}

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Detach, move to CPU, and force contiguity: safetensors
            # serializes raw tensor storage and should not receive live
            # autograd-attached (possibly CUDA, possibly non-contiguous)
            # Parameters.
            lora_state_dict[f"{name}.lora_A"] = module.lora_A.detach().cpu().contiguous()
            lora_state_dict[f"{name}.lora_B"] = module.lora_B.detach().cpu().contiguous()

    save_file(lora_state_dict, output_path)
    print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
|
|
|
|
|
class LatentDataset(Dataset):
    """Dataset of pre-encoded VAE latents stored as individual ``.pt`` files."""

    def __init__(self, latents_dir: str):
        """Index all ``*.pt`` files under ``latents_dir`` (sorted for determinism).

        Raises:
            ValueError: If the directory contains no ``.pt`` files.
        """
        self.latents_dir = Path(latents_dir)

        self.latent_files = sorted(self.latents_dir.glob("*.pt"))

        if not self.latent_files:
            raise ValueError(f"No .pt files found in {latents_dir}")

        print(f"Found {len(self.latent_files)} latent files")

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self, idx):
        # map_location='cpu' makes loading robust when latents were saved
        # from GPU tensors but this process runs on CPU (or a different
        # CUDA device); the training loop moves batches to the right
        # device afterwards anyway.
        latent = torch.load(self.latent_files[idx], map_location='cpu')

        # Normalize to 4D (1, C, H, W): some latents are saved without the
        # leading batch dimension.
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        return latent
|
|
|
|
|
class RectifiedFlow:
    """Rectified flow-matching loss around a velocity-prediction model."""

    def __init__(self, model):
        # Model must be callable as model(zt, t, cond) and return a tensor
        # shaped like zt.
        self.model = model

    def forward(self, x, cond):
        """Return the MSE flow-matching loss for a batch of latents ``x``.

        Timesteps come from a logit-normal distribution (sigmoid of a
        standard normal); ``zt`` linearly interpolates between data and
        noise, and the model is trained to predict the constant velocity
        ``noise - x`` along that path.
        """
        batch = x.size(0)

        # Logit-normal timestep sampling.
        t = torch.sigmoid(torch.randn((batch,), device=x.device))

        # Broadcast t over all non-batch dimensions of x.
        t_full = t.view(batch, *([1] * (x.ndim - 1)))
        noise = torch.randn_like(x)
        zt = (1 - t_full) * x + t_full * noise

        pred = self.model(zt, t, cond)

        velocity = noise - x
        return ((pred - velocity) ** 2).mean()
|
|
|
|
|
def collate_fn(batch, subsection_length=1024, tags=None):
    """Collate pre-encoded latents into a fixed-width training batch.

    Each latent is cropped to a random window of ``subsection_length`` along
    the last (time) axis, or right-padded with zeros when shorter, so the
    batch can be stacked into one tensor.

    Args:
        batch: List of latent tensors, each (C, H, W) or (1, C, H, W).
        subsection_length: Target width of the returned latents.
        tags: Tag-ID list replicated for every sample; defaults to the
            module-level ``HARDCODED_TAGS`` (backward-compatible).

    Returns:
        Tuple of (latents tensor (B, C, H, subsection_length), list of tag lists).
    """
    if tags is None:
        tags = HARDCODED_TAGS

    sampled_latents = []

    for latent in batch:
        # Normalize to 4D so the width indexing below is uniform.
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        _, _, _, width = latent.shape

        if width < subsection_length:
            # Too short: zero-pad the right edge of the time axis.
            pad_amount = subsection_length - width
            latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
        else:
            # Long enough: crop a random window so each epoch sees
            # different sections of the same song.
            max_start = width - subsection_length
            start_idx = torch.randint(0, max_start + 1, (1,)).item()
            latent = latent[:, :, :, start_idx:start_idx + subsection_length]

        sampled_latents.append(latent.squeeze(0))

    batch_latents = torch.stack(sampled_latents)

    batch_tags = [tags] * len(batch)

    return batch_latents, batch_tags
|
|
|
|
|
def main():
    """CLI entry point: LoRA fine-tuning of LocalSongModel on pre-encoded latents."""
    parser = argparse.ArgumentParser(description='LoRA training for LocalSong model with embedding training')

    parser.add_argument('--latents_dir', type=str, required=True,
                        help='Directory containing VAE-encoded latents (.pt files)')

    parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
                        help='Path to base model checkpoint')
    parser.add_argument('--lora_rank', type=int, default=16,
                        help='LoRA rank')
    parser.add_argument('--lora_alpha', type=float, default=16,
                        help='LoRA alpha (scaling factor)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--steps', type=int, default=1500,
                        help='Number of training steps')
    parser.add_argument('--subsection_length', type=int, default=512,
                        help='Latent subsection length')
    parser.add_argument('--output', type=str, default='lora.safetensors',
                        help='Output path for LoRA weights')
    parser.add_argument('--save_every', type=int, default=500,
                        help='Save checkpoint every N steps')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Using hardcoded tags: {HARDCODED_TAGS}")

    print(f"Loading base model from {args.checkpoint}")
    # Architecture hyperparameters are fixed here; they must match the
    # checkpoint exactly because load_state_dict below runs with strict=True.
    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    )

    print(f"Loading checkpoint from {args.checkpoint}")
    state_dict = load_file(args.checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    # Move to device BEFORE injection so inject_lora creates its LoRA
    # factors on the model's device.
    model = model.to(device)
    model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)

    model.train()

    # Optimize only the LoRA A/B factors; base weights were frozen inside
    # LoRALinear.__init__.
    lora_params = get_lora_parameters(model)
    optimizer = optim.Adam(lora_params, lr=args.lr)
    print(f"Training {len(lora_params)} LoRA parameters")

    dataset = LatentDataset(args.latents_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        # Lambda threads the crop length into collate_fn. Requires
        # num_workers=0: lambdas can't be pickled for worker processes.
        collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
    )

    rf = RectifiedFlow(model)

    print("\nStarting training...")
    step = 0
    pbar = tqdm(total=args.steps, desc="Training")

    # Rolling window for the smoothed loss shown in the progress bar.
    loss_history = deque(maxlen=50)

    # Step-based training: iterate the dataset repeatedly (multiple epochs
    # if needed) until the step budget is spent.
    while step < args.steps:
        for batch_latents, batch_tags in dataloader:
            batch_latents = batch_latents.to(device)

            optimizer.zero_grad()
            loss = rf.forward(batch_latents, batch_tags)

            loss.backward()
            # Clip only the trainable LoRA params to stabilize training.
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()

            loss_history.append(loss.item())
            avg_loss = sum(loss_history) / len(loss_history)

            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
            pbar.update(1)
            step += 1

            # Periodic checkpoint: e.g. lora.safetensors -> lora_step500.safetensors.
            if step % args.save_every == 0:
                save_path = args.output.replace('.safetensors', f'_step{step}.safetensors')
                save_lora_weights(model, save_path)

            if step >= args.steps:
                break

    save_lora_weights(model, args.output)
    print(f"\nTraining complete! LoRA weights saved to {args.output}")
|
|
|
|
|
# Standard script guard: run training only when executed directly.
if __name__ == '__main__':
    main()
|
|
|