Spaces:
Sleeping
Sleeping
| import torch | |
| from diffusers.models import UNet2DConditionModel, AutoencoderKL | |
| from diffusers.schedulers import DDPMScheduler | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from typing import Dict, Any, Optional | |
| class MiniDiffusionPipeline: | |
| # config mặc định | |
| DEFAULT_CONFIG: Dict[str, Any] = { | |
| "beta_schedule": "scaled_linear", | |
| "beta_start": 0.00085, | |
| "beta_end": 0.0120, | |
| "num_train_timesteps": 1000, | |
| "prediction_type": "epsilon", | |
| "variance_type": "fixed_small", | |
| "clip_sample": False, | |
| "rescale_betas_zero_snr": False, | |
| "timestep_spacing": "leading", | |
| "lr": 1e-4, | |
| "optimizer": "AdamW", | |
| "scheduler": "cosine", | |
| "ema_decay": 0.9999, | |
| "latent_scale": 0.18215, | |
| "text_embed_dim": 768, | |
| "latent_channels": 4, | |
| "latent_downscale_factor": 8, | |
| # --- Cấu hình kiến trúc UNet-mini --- | |
| "image_size": 128, | |
| "unet_block_out_channels": (256, 512, 1024), | |
| "unet_layers_per_block": 1, | |
| "unet_down_block_types": ( | |
| "CrossAttnDownBlock2D", | |
| "CrossAttnDownBlock2D", | |
| "DownBlock2D", | |
| ), | |
| "unet_up_block_types": ( | |
| "UpBlock2D", | |
| "CrossAttnUpBlock2D", | |
| "CrossAttnUpBlock2D", | |
| ), | |
| "unet_mid_block_type": "UNetMidBlock2DCrossAttn", | |
| "unet_attention_head_dim": 8, | |
| } | |
| def __init__( | |
| self, | |
| base_model_id: str = "stabilityai/stable-diffusion-v1-5", | |
| vae_model_id: Optional[str] = None, | |
| device: str = "cpu", | |
| config_overrides: Optional[Dict[str, Any]] = None | |
| ): | |
| self.device = torch.device(device) | |
| self.config = {**self.DEFAULT_CONFIG, **(config_overrides or {})} | |
| print(f"Đang tải Tokenizer và Text Encoder (đã đóng băng) từ {base_model_id}...") | |
| self.tokenizer = self._load_tokenizer(base_model_id) | |
| self.text_encoder = self._load_text_encoder(base_model_id) | |
| _vae_id = vae_model_id or base_model_id | |
| _vae_subfolder = "vae" if vae_model_id is None else None | |
| print(f"Đang tải VAE (để fine-tune) từ {_vae_id}...") | |
| self.vae = self._load_vae(_vae_id, _vae_subfolder) | |
| print("Khởi tạo UNet-mini (với trọng số ngẫu nhiên)...") | |
| self.unet = self._load_mini_unet() | |
| print("Khởi tạo Noise Scheduler...") | |
| self.noise_scheduler = self._load_noise_scheduler() | |
| print("\n--- MiniDiffusionPipeline đã sẵn sàng! ---") | |
| self.print_model_stats() | |
| def _load_tokenizer(self, model_id: str) -> CLIPTokenizer: | |
| return CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") | |
| def _load_text_encoder(self, model_id: str) -> CLIPTextModel: | |
| model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder") | |
| model.to(self.device) | |
| model.requires_grad_(False) | |
| return model | |
| def _load_vae(self, model_id: str, subfolder: Optional[str]) -> AutoencoderKL: | |
| if subfolder: | |
| model = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder) | |
| else: | |
| model = AutoencoderKL.from_pretrained(model_id) | |
| model.to(self.device) | |
| return model | |
| def _load_mini_unet(self) -> UNet2DConditionModel: | |
| latent_size = self.config["image_size"] // self.config["latent_downscale_factor"] | |
| unet_config = { | |
| "sample_size": latent_size, | |
| "in_channels": self.config["latent_channels"], | |
| "out_channels": self.config["latent_channels"], | |
| "block_out_channels": self.config["unet_block_out_channels"], | |
| "layers_per_block": self.config["unet_layers_per_block"], | |
| "down_block_types": self.config["unet_down_block_types"], | |
| "up_block_types": self.config["unet_up_block_types"], | |
| "mid_block_type": self.config["unet_mid_block_type"], | |
| "cross_attention_dim": self.config["text_embed_dim"], | |
| "attention_head_dim": self.config["unet_attention_head_dim"], | |
| } | |
| model = UNet2DConditionModel(**unet_config) | |
| model.to(self.device) | |
| return model | |
| def _load_noise_scheduler(self) -> DDPMScheduler: | |
| return DDPMScheduler.from_config(self.config) | |
| def print_model_stats(self): | |
| unet_params = sum(p.numel() for p in self.unet.parameters() if p.requires_grad) | |
| vae_params = sum(p.numel() for p in self.vae.parameters() if p.requires_grad) | |
| print(f" UNet-mini (để train): {unet_params / 1_000_000:.2f} triệu tham số") | |
| print(f" VAE (để fine-tune): {vae_params / 1_000_000:.2f} triệu tham số") | |
| def get_trainable_parameters(self) -> Dict[str, Any]: | |
| return { | |
| "unet": self.unet.parameters(), | |
| "vae": self.vae.parameters() | |
| } | |
| # --- KHỐI KIỂM THỬ (SMOKE TEST) --- | |
| def _run_smoke_test(): | |
| print("--- Bắt đầu kiểm thử MiniDiffusionPipeline ---") | |
| if not torch.cuda.is_available(): | |
| print("CẢNH BÁO: Không tìm thấy CUDA. Chạy trên CPU (sẽ chậm).") | |
| device = "cpu" | |
| else: | |
| device = "cuda" | |
| # --- Tải mặc định (dùng VAE của 1.5) --- | |
| print("\n--- Tải mặc định ---") | |
| pipeline_1 = MiniDiffusionPipeline( | |
| base_model_id="runwayml/stable-diffusion-v1-5", | |
| device=device | |
| ) | |
| # --- Tải VAE-MSE --- | |
| print("\n--- Tải VAE-MSE tùy chỉnh ---") | |
| pipeline_2 = MiniDiffusionPipeline( | |
| base_model_id="runwayml/stable-diffusion-v1-5", | |
| vae_model_id="stabilityai/sd-vae-ft-mse", | |
| device=device | |
| ) | |
| # --- Ghi đè config --- | |
| print("\n--- Ghi đè config (UNet siêu nhỏ) ---") | |
| tiny_config = { | |
| "unet_block_out_channels": (128, 256, 512), | |
| "lr": 5e-5 | |
| } | |
| pipeline_3 = MiniDiffusionPipeline( | |
| base_model_id="runwayml/stable-diffusion-v1-5", | |
| device=device, | |
| config_overrides=tiny_config | |
| ) | |
| print("\n--- Kiểm thử thành công ---") | |
| print(f"Config LR của Pipeline 1: {pipeline_1.config['lr']}") | |
| print(f"Config LR của Pipeline 3: {pipeline_3.config['lr']}") | |
| if __name__ == "__main__": | |
| _run_smoke_test() |