# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "diffusers>=0.25.0",
#     "transformers>=4.35.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "huggingface-hub>=0.20.0",
#     "safetensors>=0.4.0",
#     "Pillow>=10.0.0",
#     "numpy>=1.24.0",
#     "tqdm>=4.66.0",
#     "bitsandbytes",
# ]
# ///
"""
Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint.
Output: Limbicnation/pixel-art-lora
"""

import os
import sys
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import numpy as np

# Get token
token = os.environ.get("HF_TOKEN")
# The literal "$HF_TOKEN" means the launcher passed the variable unexpanded.
if not token or token == "$HF_TOKEN":
    print("ERROR: HF_TOKEN not set")
    sys.exit(1)

os.environ["HF_TOKEN"] = token

# Import after setting token
import bitsandbytes as bnb
from huggingface_hub import login, hf_hub_download, snapshot_download, create_repo, upload_file
from diffusers import FluxPipeline
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from safetensors.torch import load_file, save_file
from accelerate import Accelerator

CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
OUTPUT_REPO = "Limbicnation/pixel-art-lora"
BASE_MODEL = "black-forest-labs/FLUX.2-klein-4B"

# Step bookkeeping: the checkpoint being resumed, where training stops, and
# how often an intermediate checkpoint is saved and uploaded.
START_STEP = 500
MAX_STEPS = 1000
SAVE_EVERY = 500

WEIGHTS_NAME = "pytorch_lora_weights.safetensors"
TRAIN_IMAGE_DIR = "data/images"


class SpriteDataset(torch.utils.data.Dataset):
    """PNG images found recursively under ``root``, each with an optional caption.

    The caption is read from a ``.txt`` file that shares the image's stem;
    when no such file exists the caption is the empty string.
    """

    def __init__(self, root, res=512):
        self.imgs = sorted(Path(root).rglob("*.png"))
        self.res = res

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        path = self.imgs[idx]
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as raw:
            img = raw.convert("RGB").resize((self.res, self.res))
        # HWC uint8 in [0, 255] -> CHW float in [-1, 1].
        img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0 * 2 - 1
        txt = path.with_suffix(".txt")
        cap = txt.read_text(encoding="utf-8").strip() if txt.exists() else ""
        return {"images": img, "captions": cap}


def _save_and_upload_lora(transformer, accelerator, local_dir, path_in_repo):
    """Write the LoRA weights into ``local_dir`` and upload them to OUTPUT_REPO."""
    os.makedirs(local_dir, exist_ok=True)
    weights_path = f"{local_dir}/{WEIGHTS_NAME}"
    save_file(
        get_peft_model_state_dict(accelerator.unwrap_model(transformer)),
        weights_path,
    )
    upload_file(
        path_or_fileobj=weights_path,
        path_in_repo=path_in_repo,
        repo_id=OUTPUT_REPO,
        repo_type="model",
        token=token,
    )


def main():
    """Resume LoRA training from the step-500 checkpoint and upload the result.

    Downloads the checkpoint and dataset, wraps the pipeline's transformer
    with LoRA adapters, restores the adapter weights, trains until
    ``MAX_STEPS`` optimizer steps, and uploads intermediate and final
    weights to ``OUTPUT_REPO``. Exits with status 1 when the training image
    directory contains no PNG files.
    """
    print("="*70)
    print("šŸš€ FLUX.2-klein-4B LoRA Training - Final")
    print("="*70)
    print(f"Base model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_REPO}")
    print(f"Resume: Step {START_STEP} -> {MAX_STEPS}")

    # Login
    print("\nšŸ”‘ Authenticating...")
    login(token=token, add_to_git_credential=False)
    print("āœ… Authenticated")

    # Download checkpoint
    print("\nšŸ“„ Downloading checkpoint...")
    os.makedirs("checkpoint", exist_ok=True)
    hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename=WEIGHTS_NAME,
        repo_type="model",
        local_dir="checkpoint",
        token=token,
    )
    print("āœ… Checkpoint downloaded")

    # Download dataset
    print("\nšŸ“„ Downloading dataset...")
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="data",
        token=token,
    )
    image_files = list(Path("data").rglob("*.png"))
    print(f"āœ… Dataset: {len(image_files)} images")

    # The count above covers all of data/, but training reads only
    # TRAIN_IMAGE_DIR. Validate that directory now so a layout mismatch is
    # reported before the base model is downloaded and loaded.
    dataset = SpriteDataset(TRAIN_IMAGE_DIR)
    if len(dataset) == 0:
        print(f"ERROR: no .png images found under {TRAIN_IMAGE_DIR}")
        sys.exit(1)

    # Setup accelerator
    accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16")
    device = accelerator.device
    print(f"\nāš™ļø Device: {device}")

    # Load model
    # NOTE(review): confirm FluxPipeline is the right pipeline class for
    # BASE_MODEL in the installed diffusers version.
    print(f"\nšŸ“„ Loading {BASE_MODEL}...")
    pipe = FluxPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    # NOTE(review): CPU offload is combined with accelerator.prepare() below;
    # verify the two device-placement mechanisms do not conflict.
    pipe.enable_model_cpu_offload()
    print("āœ… Model loaded")

    # Apply LoRA
    print("\nšŸ”§ Applying LoRA (rank=64, alpha=128)...")
    # Attention q/k/v projections of transformer blocks 0..18.
    target_modules = [
        f"transformer_blocks.{i}.attn.{proj}"
        for i in range(19)
        for proj in ("to_q", "to_k", "to_v")
    ]

    lora_config = LoraConfig(r=64, lora_alpha=128, target_modules=target_modules, use_rslora=True)
    pipe.transformer = get_peft_model(pipe.transformer, lora_config)

    # Load checkpoint
    print("\nšŸ”„ Loading checkpoint...")
    state_dict = load_file(f"checkpoint/{WEIGHTS_NAME}")
    set_peft_model_state_dict(pipe.transformer, state_dict)
    print(f"āœ… Checkpoint loaded, resuming from step {START_STEP}")
    global_step = START_STEP

    # Create output repo
    print("\nšŸ“¤ Creating output repo...")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model", token=token)

    # Setup optimizer (only the parameters left trainable by the LoRA wrapper)
    trainable = [p for p in pipe.transformer.parameters() if p.requires_grad]
    optimizer = bnb.optim.AdamW8bit(trainable, lr=1e-4)

    # Dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    print(f"āœ… Dataset ready: {len(dataset)} images")

    # Prepare
    pipe.transformer, optimizer, dataloader = accelerator.prepare(
        pipe.transformer, optimizer, dataloader
    )

    # Training
    print("\n" + "="*70)
    print(f"šŸ‹ļø Training: Step {START_STEP} -> {MAX_STEPS}")
    print("="*70)

    pipe.transformer.train()
    pbar = tqdm(total=MAX_STEPS, initial=global_step, desc="Training")

    # Outer loop restarts the dataloader until enough optimizer steps are done.
    while global_step < MAX_STEPS:
        for batch in dataloader:
            with accelerator.accumulate(pipe.transformer):
                # The dataset yields float32 images while the VAE was loaded
                # in bfloat16; cast so encode() gets matching dtypes.
                imgs = batch["images"].to(device=device, dtype=pipe.vae.dtype)
                caps = [f"pixel art sprite, {c}" for c in batch["captions"]]

                with torch.no_grad():
                    latents = pipe.vae.encode(imgs).latent_dist.sample()

                # Blend latents and noise linearly by sigma = t / 1000 and
                # regress the model output onto (noise - latents).
                noise = torch.randn_like(latents)
                t = torch.rand(latents.shape[0], device=device) * 1000
                sigmas = t.view(-1, 1, 1, 1) / 1000
                noisy = (1 - sigmas) * latents + sigmas * noise
                target = noise - latents

                with torch.no_grad():
                    prompt_embeds = pipe.encode_prompt(caps)[0]

                # NOTE(review): only hidden_states, timestep and
                # encoder_hidden_states are passed; confirm this matches the
                # forward signature and input layout the transformer expects.
                output = pipe.transformer(
                    hidden_states=noisy,
                    timestep=t,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )[0]

                loss = F.mse_loss(output.float(), target.float())
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(pipe.transformer.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()

            # Count a step only when the accumulated gradients were applied.
            if accelerator.sync_gradients:
                global_step += 1
                pbar.update(1)
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

                if global_step % SAVE_EVERY == 0:
                    print(f"\nšŸ’¾ Saving checkpoint at step {global_step}...")
                    _save_and_upload_lora(
                        pipe.transformer,
                        accelerator,
                        f"output/step_{global_step}",
                        f"step_{global_step}/{WEIGHTS_NAME}",
                    )
                    print("āœ… Checkpoint saved")

            if global_step >= MAX_STEPS:
                break

    pbar.close()

    # Final save
    print("\nšŸ’¾ Saving final model...")
    _save_and_upload_lora(pipe.transformer, accelerator, "output/final", WEIGHTS_NAME)

    print("\n" + "="*70)
    print("āœ… Training Complete!")
    print("="*70)
    print(f"\nšŸ“¤ Model: https://huggingface.co/{OUTPUT_REPO}")


if __name__ == "__main__":
    main()