# train.py
import os
import torch
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
# 1️⃣ Pick your scheduler class
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
    AutoencoderKL,
)
from transformers import CLIPTextModel, CLIPTokenizer
# ─── 1) CONFIG ────────────────────────────────────────────────────────────────
DATA_DIR = os.getenv("DATA_DIR", "./data")
MODEL_DIR = os.getenv("MODEL_DIR", "./hidream-model")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")
# ─── 2) DOWNLOAD OR VERIFY BASE MODEL ─────────────────────────────────────────
if not os.path.isdir(MODEL_DIR):
    MODEL_DIR = snapshot_download(
        repo_id="HiDream-ai/HiDream-I1-Dev",
        local_dir=MODEL_DIR,
    )
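# snapshot_download() returns the local directory it populated, so MODEL_DIR
# points at a valid local copy whether the model was already on disk or not.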
# ─── 3) LOAD EACH PIPELINE COMPONENT ──────────────────────────────────────────
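# NOTE: the subfolder names below assume the checkpoint follows the standard
# Stable-Diffusion-style repo layout (scheduler/, vae/, text_encoder/, ...).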
# 3a) Scheduler
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_DIR,
    subfolder="scheduler",
)
# 3b) VAE
vae = AutoencoderKL.from_pretrained(
    MODEL_DIR,
    subfolder="vae",
    torch_dtype=torch.float16,
).to("cuda")
# 3c) Text encoder + tokenizer
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_DIR,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_DIR,
    subfolder="tokenizer",
)
# 3d) U-Net
unet = UNet2DConditionModel.from_pretrained(
    MODEL_DIR,
    subfolder="unet",
    torch_dtype=torch.float16,
).to("cuda")
# ─── 4) BUILD THE PIPELINE ────────────────────────────────────────────────────
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    # safety_checker and feature_extractor are required constructor arguments;
    # neither is needed for training, so pass None explicitly.
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to("cuda")
# ─── 5) APPLY LORA ────────────────────────────────────────────────────────────
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    # task_type="CAUSAL_LM" is for causal language models, not a diffusion
    # U-Net; instead, target the U-Net's attention projection layers directly.
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
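# Sanity check: confirm that only the LoRA adapter weights are trainable.
pipe.unet.print_trainable_parameters()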
# ─── 6) TRAINING LOOP (SIMULATED) ─────────────────────────────────────────────
print(f"📂 Data at {DATA_DIR}")
for step in range(100):
    # … your real data loading + optimizer here …
    print(f"Training step {step+1}/100")
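# A minimal sketch of what one real denoising step might look like, assuming a
# DataLoader that yields `pixel_values` (images) and `input_ids` (tokenized
# captions); those names are illustrative, not part of this script:
#
#   latents = vae.encode(pixel_values).latent_dist.sample()
#   latents = latents * vae.config.scaling_factor
#   noise = torch.randn_like(latents)
#   timesteps = torch.randint(
#       0, scheduler.config.num_train_timesteps,
#       (latents.shape[0],), device=latents.device,
#   )
#   noisy_latents = scheduler.add_noise(latents, noise, timesteps)
#   encoder_hidden_states = text_encoder(input_ids)[0]
#   noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
#   loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()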
# ─── 7) SAVE THE FINE-TUNED LoRA ──────────────────────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)
# pipe.unet is now a PeftModel, so save_pretrained() writes only the small
# LoRA adapter weights, which is what "save the LoRA" means here.
pipe.unet.save_pretrained(OUTPUT_DIR)
print("✅ Done! Saved to", OUTPUT_DIR)
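# Later, in a fresh session, the adapter can be re-attached for inference.
# A sketch using PeftModel.from_pretrained, the standard PEFT loading entry
# point (the prompt and output filename below are placeholders):
#
#   from peft import PeftModel
#   pipe.unet = PeftModel.from_pretrained(pipe.unet, OUTPUT_DIR)
#   image = pipe("a test prompt", num_inference_steps=25).images[0]
#   image.save("sample.png")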