| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
"""
Resume FLUX.2-klein-4B LoRA training from the step-500 checkpoint.

Output repository: Limbicnation/pixel-art-lora
"""
| |
|
| | import os |
| | import sys |
| | import torch |
| | import torch.nn.functional as F |
| | from pathlib import Path |
| | from tqdm import tqdm |
| | from PIL import Image |
| | import numpy as np |
| |
|
| | |
# Fail fast, before the heavy ML imports that follow, when no usable Hugging
# Face token is configured. The literal "$HF_TOKEN" is rejected because that is
# what an unexpanded shell variable looks like. Surrounding whitespace (e.g. a
# trailing newline from a pasted secret) is stripped so it cannot break auth.
token = os.environ.get("HF_TOKEN", "").strip()
if not token or token == "$HF_TOKEN":
    print("ERROR: HF_TOKEN not set", file=sys.stderr)
    sys.exit(1)

# Write the normalized value back so anything that reads HF_TOKEN from the
# environment sees the same token that is passed explicitly later on.
os.environ["HF_TOKEN"] = token
| |
|
| | |
| | from huggingface_hub import login, hf_hub_download, snapshot_download, create_repo, upload_file |
| | from diffusers import FluxPipeline |
| | from peft import LoraConfig, get_peft_model, set_peft_model_state_dict |
| | from safetensors.torch import load_file, save_file |
| | from accelerate import Accelerator |
| |
|
# Hugging Face Hub identifiers used by this script.
BASE_MODEL = "black-forest-labs/FLUX.2-klein-4B"  # base pipeline the LoRA adapts
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"  # model repo with the step-500 LoRA weights
DATASET_REPO = "Limbicnation/sprite-lora-training-data"  # dataset repo downloaded into ./data
OUTPUT_REPO = "Limbicnation/pixel-art-lora"  # model repo receiving new checkpoints and final weights
| |
|
class _SpriteDataset(torch.utils.data.Dataset):
    """PNG sprites found recursively under ``root``, with optional captions.

    Each item is a dict with:
      * ``"images"``: float tensor of shape (3, res, res) scaled to [-1, 1];
      * ``"captions"``: stripped text of the same-stem ``.txt`` file, or ``""``
        when no such file exists.
    """

    def __init__(self, root, res=512):
        self.imgs = sorted(Path(root).rglob("*.png"))
        self.res = res

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        path = self.imgs[idx]
        # Image.open is lazy and keeps the file handle open; the context manager
        # closes it. NEAREST keeps hard pixel edges that the default
        # (interpolating) resampling filter would blur.
        with Image.open(path) as im:
            img = im.convert("RGB").resize((self.res, self.res), Image.NEAREST)
        pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0 * 2 - 1
        txt = path.with_suffix(".txt")
        cap = txt.read_text(encoding="utf-8").strip() if txt.exists() else ""
        return {"images": pixels, "captions": cap}


def _save_and_upload_lora(accelerator, transformer, local_dir, path_in_repo):
    """Save the LoRA adapter weights of ``transformer`` and upload them to OUTPUT_REPO.

    Writes ``<local_dir>/pytorch_lora_weights.safetensors`` and uploads it as
    ``path_in_repo``. Only the main process does this, so multi-process runs do
    not write or upload the same file several times.
    """
    # Imported here because the module-level peft import does not provide it.
    from peft import get_peft_model_state_dict

    if not accelerator.is_main_process:
        return
    os.makedirs(local_dir, exist_ok=True)
    local_path = f"{local_dir}/pytorch_lora_weights.safetensors"
    save_file(get_peft_model_state_dict(accelerator.unwrap_model(transformer)), local_path)
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=OUTPUT_REPO,
        repo_type="model",
        token=token,
    )


def main():
    """Resume LoRA training of BASE_MODEL from step 500 to step 1000.

    Authenticates, downloads the step-500 LoRA checkpoint and the sprite
    dataset, loads the base pipeline, wraps its transformer with a LoRA adapter
    (rank 64, alpha 128, rsLoRA) on the q/k/v projections of
    ``transformer_blocks.0``..``18``, loads the checkpoint into the adapter,
    then trains with 8-bit AdamW and 4-step gradient accumulation. A checkpoint
    is saved/uploaded every 500 optimizer steps, and the final weights are
    uploaded to the root of OUTPUT_REPO. Exits with status 1 if the checkpoint
    keys do not match the adapter or no training images are found.
    """
    import bitsandbytes as bnb

    start_step = 500
    max_steps = 1000
    save_every = 500

    print("=" * 70)
    print("🚀 FLUX.2-klein-4B LoRA Training - Final")
    print("=" * 70)
    print(f"Base model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_REPO}")
    print(f"Resume: Step {start_step} -> {max_steps}")

    # --- Authentication ------------------------------------------------------
    print("\n🔐 Authenticating...")
    login(token=token, add_to_git_credential=False)
    print("✅ Authenticated")

    # --- Downloads -----------------------------------------------------------
    print("\n📥 Downloading checkpoint...")
    os.makedirs("checkpoint", exist_ok=True)
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="checkpoint",
        token=token,
    )
    print("✅ Checkpoint downloaded")

    print("\n📥 Downloading dataset...")
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="data",
        token=token,
    )
    image_files = list(Path("data").rglob("*.png"))
    print(f"✅ Dataset: {len(image_files)} images")

    accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16")
    device = accelerator.device
    print(f"\n⚙️ Device: {device}")

    # --- Model + LoRA --------------------------------------------------------
    print(f"\n📥 Loading {BASE_MODEL}...")
    # NOTE(review): FLUX.2 checkpoints may need a FLUX.2-specific pipeline class
    # rather than FluxPipeline -- confirm against the installed diffusers.
    pipe = FluxPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    pipe.enable_model_cpu_offload()
    print("✅ Model loaded")

    print("\n🔧 Applying LoRA (rank=64, alpha=128)...")
    target_modules = [
        f"transformer_blocks.{i}.attn.{proj}"
        for i in range(19)
        for proj in ("to_q", "to_k", "to_v")
    ]
    lora_config = LoraConfig(r=64, lora_alpha=128, target_modules=target_modules, use_rslora=True)
    pipe.transformer = get_peft_model(pipe.transformer, lora_config)

    print("\n🔄 Loading checkpoint...")
    state_dict = load_file(checkpoint_path)
    load_result = set_peft_model_state_dict(pipe.transformer, state_dict)
    # peft loads non-strictly and reports keys it could not place. If none of
    # them matched (e.g. a key-prefix mismatch), "resuming" would silently
    # start from a freshly initialised adapter.
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if state_dict and len(unexpected) >= len(state_dict):
        print("ERROR: no checkpoint keys matched the LoRA modules; refusing to resume", file=sys.stderr)
        sys.exit(1)
    if unexpected:
        print(f"WARNING: {len(unexpected)} checkpoint keys were not loaded, e.g. {unexpected[0]}")
    print(f"✅ Checkpoint loaded, resuming from step {start_step}")

    global_step = start_step

    print("\n📤 Creating output repo...")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model", token=token)

    trainable = [p for p in pipe.transformer.parameters() if p.requires_grad]
    optimizer = bnb.optim.AdamW8bit(trainable, lr=1e-4)

    # --- Data ----------------------------------------------------------------
    dataset = _SpriteDataset("data/images")
    if len(dataset) == 0:
        # An empty DataLoader would make the `while` loop below spin forever.
        print("ERROR: no PNG images found under data/images", file=sys.stderr)
        sys.exit(1)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    print(f"✅ Dataset ready: {len(dataset)} images")

    pipe.transformer, optimizer, dataloader = accelerator.prepare(
        pipe.transformer, optimizer, dataloader
    )

    # --- Training ------------------------------------------------------------
    print("\n" + "=" * 70)
    print(f"🏋️ Training: Step {start_step} -> {max_steps}")
    print("=" * 70)

    pipe.transformer.train()
    pbar = tqdm(
        total=max_steps,
        initial=global_step,
        desc="Training",
        disable=not accelerator.is_local_main_process,
    )

    while global_step < max_steps:
        for batch in dataloader:
            with accelerator.accumulate(pipe.transformer):
                # The VAE is not passed through accelerator.prepare, so no
                # autocast covers it: give it inputs in its own (bf16) dtype.
                imgs = batch["images"].to(device, dtype=pipe.vae.dtype)
                caps = [
                    f"pixel art sprite, {c}" if c else "pixel art sprite"
                    for c in batch["captions"]
                ]

                with torch.no_grad():
                    # NOTE(review): latents are used without the VAE's
                    # scaling/shift normalisation -- confirm this matches the
                    # latent space the transformer expects.
                    latents = pipe.vae.encode(imgs).latent_dist.sample()
                    noise = torch.randn_like(latents)
                    # Rectified-flow interpolation x_s = (1 - s) * x0 + s * noise
                    # with s ~ U[0, 1); the regression target is noise - x0.
                    sigmas = torch.rand(latents.shape[0], device=device)
                    sigmas_b = sigmas.view(-1, 1, 1, 1)
                    noisy = (1 - sigmas_b) * latents + sigmas_b * noise
                    target = noise - latents
                    prompt_embeds = pipe.encode_prompt(caps)[0]

                # The diffusers Flux pipelines pass timestep / 1000 (i.e. the
                # sigma in [0, 1]); the transformer rescales it internally.
                # NOTE(review): Flux transformers normally also expect packed
                # latents plus pooled_projections, img_ids and txt_ids --
                # confirm this call against the installed forward() signature.
                output = pipe.transformer(
                    hidden_states=noisy,
                    timestep=sigmas,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )[0]

                loss = F.mse_loss(output.float(), target.float())
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(pipe.transformer.parameters(), 1.0)

                # Inside accumulate(), the prepared optimizer only actually
                # steps and zeroes on gradient-sync iterations.
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                global_step += 1
                pbar.update(1)
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

                if global_step % save_every == 0:
                    print(f"\n💾 Saving checkpoint at step {global_step}...")
                    _save_and_upload_lora(
                        accelerator,
                        pipe.transformer,
                        f"output/step_{global_step}",
                        f"step_{global_step}/pytorch_lora_weights.safetensors",
                    )
                    print("✅ Checkpoint saved")

            if global_step >= max_steps:
                break

    pbar.close()
    accelerator.wait_for_everyone()

    print("\n💾 Saving final model...")
    _save_and_upload_lora(
        accelerator,
        pipe.transformer,
        "output/final",
        "pytorch_lora_weights.safetensors",
    )

    print("\n" + "=" * 70)
    print("✅ Training Complete!")
    print("=" * 70)
    print(f"\n🤗 Model: https://huggingface.co/{OUTPUT_REPO}")
| |
|
if __name__ == "__main__":
    # Start training only when this file is executed directly as a script.
    main()
| |
|