# Hugging Face Spaces — status: Paused
import os

import torch
from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# Fine-tune FLUX.1-dev by training a LoRA adapter on its transformer backbone.
#
# FLUX.1-dev is not a Stable Diffusion checkpoint. Its repo has a `transformer`
# subfolder (no `unet`), a second T5 text encoder (`text_encoder_2` /
# `tokenizer_2`), and a flow-matching scheduler, so it is loaded with the Flux
# classes from diffusers rather than StableDiffusionPipeline.
#
# The repo is gated. huggingface_hub picks up credentials from
# `huggingface-cli login` or the HF_TOKEN environment variable (Space secret).

MODEL_ID = "black-forest-labs/FLUX.1-dev"
LOCAL_DIR = "./fluxdev-model"
output_dir = "./fluxdev-lora"  # where the trained LoRA adapter is written
# diffusers' FLUX examples load in bfloat16; float16 risks overflow with these weights.
DTYPE = torch.bfloat16
NUM_STEPS = 100

# 1) grab the model locally (a single download; every component below reads
#    from this local snapshot, so no further auth is needed)
print("📥 Downloading Flux‑Dev model…")
model_path = snapshot_download(MODEL_ID, local_dir=LOCAL_DIR)

# 2) load each piece from its subfolder of the snapshot
print("🔄 Loading scheduler…")
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_path, subfolder="scheduler"
)

print("🔄 Loading VAE…")
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=DTYPE)

print("🔄 Loading text encoders + tokenizers…")
text_encoder = CLIPTextModel.from_pretrained(
    model_path, subfolder="text_encoder", torch_dtype=DTYPE
)
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder_2 = T5EncoderModel.from_pretrained(
    model_path, subfolder="text_encoder_2", torch_dtype=DTYPE
)
tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2")

print("🔄 Loading transformer…")
transformer = FluxTransformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=DTYPE
)

# 3) assemble the pipeline exactly once and move it to the GPU
print("🛠 Assembling pipeline…")
pipe = FluxPipeline(
    scheduler=scheduler,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    transformer=transformer,
).to("cuda")

# 4) apply LoRA
print("🧠 Applying LoRA…")
# Freeze the base weights so only the injected LoRA matrices receive gradients.
pipe.transformer.requires_grad_(False)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    # peft cannot infer targets for a non-transformers model, so name them:
    # the attention projections, as in diffusers' Flux DreamBooth LoRA example.
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# add_adapter injects LoRA in place, so `pipe.transformer` stays a diffusers model
# that the pipeline and save_lora_weights understand. get_peft_model would instead
# wrap it in a PeftModel, and task_type="CAUSAL_LM" is meant for language models.
pipe.transformer.add_adapter(lora_config)

# 5) your training loop (or dummy loop for illustration)
print("🚀 Starting fine‑tuning…")
pipe.transformer.train()
for step in range(NUM_STEPS):
    print(f"Training step {step + 1}/{NUM_STEPS}")
    # …insert your actual data‑loader and loss/backprop here…

# 6) save only the LoRA adapter weights, not the whole multi-GB pipeline
os.makedirs(output_dir, exist_ok=True)
FluxPipeline.save_lora_weights(
    output_dir,
    transformer_lora_layers=get_peft_model_state_dict(pipe.transformer),
)
print("✅ Done. LoRA weights in", output_dir)