# Hugging Face Space (status: Paused)
"""Fine-tune a LoRA adapter on a diffusion model with AI-Toolkit.

Reads image/caption pairs from ``DATA_DIR``, attaches a LoRA adapter to the
base model ``MODEL_ID``, trains it with the settings in ``training_args`` and
writes checkpoints to ``OUTPUT_DIR``.
"""
import os

import torch
from aitoolkit import (
    ImageTextDataset,
    LoRAConfig,
    LoRATrainer,
    StableDiffusionModel,
)

# 1. Configuration
# NOTE(review): whether StableDiffusionModel can load HiDream (or the FLUX
# alternative mentioned below) depends on aitoolkit and is not verified here.
MODEL_ID = "HiDream-ai/HiDream-I1-Dev"  # or your gated FLUX model if you have access
DATA_DIR = "/workspace/data"
OUTPUT_DIR = "/workspace/lora-trained"
IMAGE_SIZE = 512  # image size passed to ImageTextDataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only makes sense on the GPU: many CPU kernels have no
# float16 implementation (or are very slow), so the CPU fallback would break
# once training starts. Use full precision there instead.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

lora_cfg = LoRAConfig(
    rank=16,
    alpha=16,
    bias="none",
)

training_args = {
    "num_train_steps": 100,
    "batch_size": 4,
    "learning_rate": 1e-4,
    "save_every_n_steps": 50,
    "output_dir": OUTPUT_DIR,
}

# Fail fast on a missing dataset directory rather than after the (large)
# base-model download and load below.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(f"Training data directory not found: {DATA_DIR}")

# Create the checkpoint directory up front, in case the trainer does not
# create it itself before the first save.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 2. Load the base diffusion model.
model = StableDiffusionModel.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device=DEVICE,
    # Only needed for gated repos.
    # NOTE(review): this is passed unconditionally. If aitoolkit forwards it
    # to huggingface_hub, True with no stored token may raise even for public
    # models, and newer HF libraries rename it to `token`. Confirm against
    # the installed aitoolkit.
    use_auth_token=True,
)

# 3. Prepare the dataset.
# Expects pairs of image files + .txt captions in DATA_DIR.
dataset = ImageTextDataset(data_root=DATA_DIR, image_size=IMAGE_SIZE)

# 4. Attach the LoRA adapter to the base model.
model.apply_lora(lora_cfg)

# 5. Create the trainer and kick off training.
trainer = LoRATrainer(
    model=model,
    dataset=dataset,
    args=training_args,
)

print("🚀 Starting training with AI‑Toolkit…")
trainer.train()
print(f"✅ Done! Fine-tuned weights saved to {OUTPUT_DIR}")