import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import argparse
from tqdm import tqdm
from safetensors.torch import save_file, load_file
from collections import deque

from model import LocalSongModel

HARDCODED_TAGS = [1908]

torch.set_float32_matmul_precision('high')


class LoRALinear(nn.Module):
    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # A is Kaiming-initialised, B starts at zero, so the LoRA delta is zero at step 0.
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)
        # Freeze the wrapped layer; only lora_A and lora_B receive gradients.
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False

    def forward(self, x):
        # Frozen base output plus the scaled low-rank update.
        result = self.original_linear(x)
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
        return result + lora_out


def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0,
                target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'),
                device=None):
    """Inject LoRA layers into the model."""
    lora_modules = []
    if device is None:
        device = next(model.parameters()).device
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if any(target in name for target in target_modules):
                # Resolve the parent module so the Linear can be swapped for its LoRA wrapper.
                *parent_path, attr_name = name.split('.')
                parent = model
                for p in parent_path:
                    parent = getattr(parent, p)
                lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
                lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
                lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
                setattr(parent, attr_name, lora_layer)
                lora_modules.append(name)
    print(f"Injected LoRA into {len(lora_modules)} layers:")
    for name in lora_modules[:5]:
        print(f"  - {name}")
    if len(lora_modules) > 5:
        print(f"  ... and {len(lora_modules) - 5} more")
    return model


def get_lora_parameters(model):
    """Extract only LoRA parameters for optimization."""
    lora_params = []
    for module in model.modules():
        if isinstance(module, LoRALinear):
            lora_params.extend([module.lora_A, module.lora_B])
    return lora_params


def save_lora_weights(model, output_path):
    """Save LoRA weights to a safetensors file."""
    lora_state_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            lora_state_dict[f"{name}.lora_A"] = module.lora_A
            lora_state_dict[f"{name}.lora_B"] = module.lora_B
    save_file(lora_state_dict, output_path)
    print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
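# NOTE: not part of the original script -- a minimal sketch of the inverse of
# save_lora_weights, assuming LoRA layers were already injected with inject_lora()
# so that module names line up with the saved keys.
def load_lora_weights(model, lora_path):
    """Load LoRA weights produced by save_lora_weights into an injected model."""
    lora_state_dict = load_file(lora_path)
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # copy_ keeps the device/dtype of the already-injected parameters.
            module.lora_A.data.copy_(lora_state_dict[f"{name}.lora_A"].to(module.lora_A.device))
            module.lora_B.data.copy_(lora_state_dict[f"{name}.lora_B"].to(module.lora_B.device))
    print(f"Loaded {len(lora_state_dict)} LoRA parameters from {lora_path}")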
class LatentDataset(Dataset):
    """Dataset for pre-encoded latents."""

    def __init__(self, latents_dir: str):
        self.latents_dir = Path(latents_dir)
        self.latent_files = sorted(list(self.latents_dir.glob("*.pt")))
        if len(self.latent_files) == 0:
            raise ValueError(f"No .pt files found in {latents_dir}")
        print(f"Found {len(self.latent_files)} latent files")

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self, idx):
        latent = torch.load(self.latent_files[idx])
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)
        return latent


class RectifiedFlow:
    """Simplified rectified flow matching."""

    def __init__(self, model):
        self.model = model

    def forward(self, x, cond):
        """Compute flow matching loss."""
        b = x.size(0)
        # Logit-normal timestep sampling: t = sigmoid(n) with n ~ N(0, 1).
        nt = torch.randn((b,), device=x.device)
        t = torch.sigmoid(nt)
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        # Linear interpolation between data x (t=0) and noise z1 (t=1).
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1
        vtheta = self.model(zt, t, cond)
        # The model regresses the constant velocity of the straight path, z1 - x.
        target = z1 - x
        loss = ((vtheta - target) ** 2).mean()
        return loss


def collate_fn(batch, subsection_length=1024):
    """Custom collate function to sample random subsections."""
    sampled_latents = []
    for latent in batch:
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)
        _, _, _, width = latent.shape
        if width < subsection_length:
            # Pad if too short
            pad_amount = subsection_length - width
            latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
        else:
            # Randomly sample subsection
            max_start = width - subsection_length
            start_idx = torch.randint(0, max_start + 1, (1,)).item()
            latent = latent[:, :, :, start_idx:start_idx + subsection_length]
        sampled_latents.append(latent.squeeze(0))
    batch_latents = torch.stack(sampled_latents)
    batch_tags = [HARDCODED_TAGS] * len(batch)
    return batch_latents, batch_tags
def main():
    parser = argparse.ArgumentParser(description='LoRA fine-tuning for the LocalSong model on pre-encoded latents')
    parser.add_argument('--latents_dir', type=str, required=True,
                        help='Directory containing VAE-encoded latents (.pt files)')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
                        help='Path to base model checkpoint')
    parser.add_argument('--lora_rank', type=int, default=16, help='LoRA rank')
    parser.add_argument('--lora_alpha', type=float, default=16, help='LoRA alpha (scaling factor)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
    parser.add_argument('--steps', type=int, default=1500, help='Number of training steps')
    parser.add_argument('--subsection_length', type=int, default=512, help='Latent subsection length')
    parser.add_argument('--output', type=str, default='lora.safetensors', help='Output path for LoRA weights')
    parser.add_argument('--save_every', type=int, default=500, help='Save checkpoint every N steps')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Using hardcoded tags: {HARDCODED_TAGS}")

    print(f"Loading base model from {args.checkpoint}")
    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    )

    print(f"Loading checkpoint from {args.checkpoint}")
    state_dict = load_file(args.checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    model = model.to(device)
    model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)
    model.train()

    lora_params = get_lora_parameters(model)
    optimizer = optim.Adam(lora_params, lr=args.lr)
    print(f"Training {len(lora_params)} LoRA parameters")

    dataset = LatentDataset(args.latents_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
    )

    rf = RectifiedFlow(model)

    print("\nStarting training...")
    step = 0
    pbar = tqdm(total=args.steps, desc="Training")
    loss_history = deque(maxlen=50)

    while step < args.steps:
        for batch_latents, batch_tags in dataloader:
            batch_latents = batch_latents.to(device)

            optimizer.zero_grad()
            loss = rf.forward(batch_latents, batch_tags)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()

            # Track loss and compute average
            loss_history.append(loss.item())
            avg_loss = sum(loss_history) / len(loss_history)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
            pbar.update(1)
            step += 1

            if step % args.save_every == 0:
                save_path = args.output.replace('.safetensors', f'_step{step}.safetensors')
                save_lora_weights(model, save_path)

            if step >= args.steps:
                break

    pbar.close()
    save_lora_weights(model, args.output)
    print(f"\nTraining complete! LoRA weights saved to {args.output}")


if __name__ == '__main__':
    main()
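# Example invocation (illustrative only; the script filename "train_lora.py" and the
# latents directory path are assumptions, not taken from the source):
#
#   python train_lora.py \
#       --latents_dir ./latents \
#       --checkpoint checkpoints/checkpoint_461260.safetensors \
#       --lora_rank 16 --lora_alpha 16 \
#       --batch_size 16 --lr 2e-4 --steps 1500 \
#       --subsection_length 512 \
#       --output lora.safetensors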