Spaces:
Paused
Paused
| # train.py | |
| import os | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
| from peft import LoraConfig, get_peft_model | |
| # ββ 1) Configuration βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Where you put your images + prompts | |
| DATA_DIR = os.getenv("DATA_DIR", "./data") | |
| # Where your base model lives (downloaded or cached) | |
| MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model") | |
| # Where to save your LoRAβfineβtuned model | |
| OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained") | |
| # ββ 2) Prepare the base model snapshot ββββββββββββββββββββββββββββββββββββββββ | |
| print(f"π Loading dataset from: {DATA_DIR}") | |
| print("π₯ Fetching or verifying base model: HiDream-ai/HiDream-I1-Dev") | |
| # If youβve preβdownloaded into MODEL_DIR, just use it; otherwise pull from HF Hub | |
| if not os.path.isdir(MODEL_DIR): | |
| MODEL_DIR = snapshot_download( | |
| repo_id="HiDream-ai/HiDream-I1-Dev", | |
| local_dir=MODEL_DIR | |
| ) | |
| # ββ 3) Load the scheduler manually βββββββββββββββββββββββββββββββββββββββββββββ | |
| # Diffusersβ scheduler config JSON points at FlowMatchLCMScheduler, | |
| # but your installed version doesnβt have that class. Instead we | |
| # forceβload DPMSolverMultistepScheduler via `from_pretrained`. | |
| print(f"π Loading scheduler from: {MODEL_DIR}/scheduler") | |
| scheduler = DPMSolverMultistepScheduler.from_pretrained( | |
| pretrained_model_name_or_path=MODEL_DIR, | |
| subfolder="scheduler" | |
| ) | |
| # ββ 4) Build the Stable Diffusion pipeline ββββββββββββββββββββββββββββββββββββ | |
| print("π§ Creating StableDiffusionPipeline with custom scheduler") | |
| pipe = StableDiffusionPipeline.from_pretrained( | |
| pretrained_model_name_or_path=MODEL_DIR, | |
| scheduler=scheduler, | |
| torch_dtype=torch.float16, | |
| ).to("cuda") | |
| # ββ 5) Apply PEFT LoRA adapters βββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("π§ Configuring LoRA adapter on UβNet") | |
| lora_config = LoraConfig( | |
| r=16, | |
| lora_alpha=16, | |
| bias="none", | |
| task_type="CAUSAL_LM" | |
| ) | |
| pipe.unet = get_peft_model(pipe.unet, lora_config) | |
| # ββ 6) (Placeholder) Simulate your training loop βββββββββββββββββββββββββββββ | |
| print("π Starting fineβtuning loop (simulated)") | |
| for step in range(100): | |
| # Here you'd load your data, compute loss, do optimizer.step(), etc. | |
| print(f" Training step {step+1}/100") | |
| # ββ 7) Save your LoRAβtuned model ββββββββββββββββββββββββββββββββββββββββββββ | |
| os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| pipe.save_pretrained(OUTPUT_DIR) | |
| print("β Training complete. Model saved to", OUTPUT_DIR) | |