File size: 6,357 Bytes
4cfe4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
from diffusers.models import UNet2DConditionModel, AutoencoderKL
from diffusers.schedulers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from typing import Dict, Any, Optional

class MiniDiffusionPipeline:

    
    # config mặc định
    DEFAULT_CONFIG: Dict[str, Any] = {
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.0120,
        "num_train_timesteps": 1000,
        "prediction_type": "epsilon",
        "variance_type": "fixed_small",
        "clip_sample": False,
        "rescale_betas_zero_snr": False,
        "timestep_spacing": "leading",
        "lr": 1e-4,
        "optimizer": "AdamW",
        "scheduler": "cosine",
        "ema_decay": 0.9999,
        "latent_scale": 0.18215,
        "text_embed_dim": 768,
        "latent_channels": 4,
        "latent_downscale_factor": 8,
        
        # --- Cấu hình kiến trúc UNet-mini ---
        "image_size": 128, 
        "unet_block_out_channels": (256, 512, 1024),
        "unet_layers_per_block": 1,
        "unet_down_block_types": (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        "unet_up_block_types": (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        "unet_mid_block_type": "UNetMidBlock2DCrossAttn",
        "unet_attention_head_dim": 8,
    }

    def __init__(
        self,
        base_model_id: str = "stabilityai/stable-diffusion-v1-5",
        vae_model_id: Optional[str] = None,
        device: str = "cpu",
        config_overrides: Optional[Dict[str, Any]] = None
    ):
        self.device = torch.device(device)
        
        self.config = {**self.DEFAULT_CONFIG, **(config_overrides or {})}

        print(f"Đang tải Tokenizer và Text Encoder (đã đóng băng) từ {base_model_id}...")
        self.tokenizer = self._load_tokenizer(base_model_id)
        self.text_encoder = self._load_text_encoder(base_model_id)

        _vae_id = vae_model_id or base_model_id
        _vae_subfolder = "vae" if vae_model_id is None else None
        print(f"Đang tải VAE (để fine-tune) từ {_vae_id}...")
        self.vae = self._load_vae(_vae_id, _vae_subfolder)
        
        print("Khởi tạo UNet-mini (với trọng số ngẫu nhiên)...")
        self.unet = self._load_mini_unet()

        print("Khởi tạo Noise Scheduler...")
        self.noise_scheduler = self._load_noise_scheduler()
        
        print("\n--- MiniDiffusionPipeline đã sẵn sàng! ---")
        self.print_model_stats()

    def _load_tokenizer(self, model_id: str) -> CLIPTokenizer:
        return CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    def _load_text_encoder(self, model_id: str) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
        model.to(self.device)
        model.requires_grad_(False)
        return model

    def _load_vae(self, model_id: str, subfolder: Optional[str]) -> AutoencoderKL:
        if subfolder:
            model = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder)
        else:
            model = AutoencoderKL.from_pretrained(model_id)
        model.to(self.device)
        return model

    def _load_mini_unet(self) -> UNet2DConditionModel:

        latent_size = self.config["image_size"] // self.config["latent_downscale_factor"]
        
        unet_config = {
            "sample_size": latent_size,
            "in_channels": self.config["latent_channels"],
            "out_channels": self.config["latent_channels"],
            "block_out_channels": self.config["unet_block_out_channels"],
            "layers_per_block": self.config["unet_layers_per_block"],
            "down_block_types": self.config["unet_down_block_types"],
            "up_block_types": self.config["unet_up_block_types"],
            "mid_block_type": self.config["unet_mid_block_type"],
            "cross_attention_dim": self.config["text_embed_dim"],
            "attention_head_dim": self.config["unet_attention_head_dim"],
        }
        
        model = UNet2DConditionModel(**unet_config)
        model.to(self.device)
        return model

    def _load_noise_scheduler(self) -> DDPMScheduler:
        return DDPMScheduler.from_config(self.config)

    def print_model_stats(self):
        unet_params = sum(p.numel() for p in self.unet.parameters() if p.requires_grad)
        vae_params = sum(p.numel() for p in self.vae.parameters() if p.requires_grad)
        print(f"  UNet-mini (để train): {unet_params / 1_000_000:.2f} triệu tham số")
        print(f"  VAE (để fine-tune): {vae_params / 1_000_000:.2f} triệu tham số")
        
    def get_trainable_parameters(self) -> Dict[str, Any]:
        return {
            "unet": self.unet.parameters(),
            "vae": self.vae.parameters()
        }


# --- KHỐI KIỂM THỬ (SMOKE TEST) ---
def _run_smoke_test():
    print("--- Bắt đầu kiểm thử MiniDiffusionPipeline ---")
    
    if not torch.cuda.is_available():
        print("CẢNH BÁO: Không tìm thấy CUDA. Chạy trên CPU (sẽ chậm).")
        device = "cpu"
    else:
        device = "cuda"

    # --- Tải mặc định (dùng VAE của 1.5) ---
    print("\n--- Tải mặc định ---")
    pipeline_1 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device
    )
    
    # --- Tải VAE-MSE ---
    print("\n--- Tải VAE-MSE tùy chỉnh ---")
    pipeline_2 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5", 
        vae_model_id="stabilityai/sd-vae-ft-mse",      
        device=device
    )

    # ---  Ghi đè config ---
    print("\n--- Ghi đè config (UNet siêu nhỏ) ---")
    tiny_config = {
        "unet_block_out_channels": (128, 256, 512), 
        "lr": 5e-5 
    }
    pipeline_3 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        config_overrides=tiny_config
    )
    
    print("\n--- Kiểm thử thành công ---")
    print(f"Config LR của Pipeline 1: {pipeline_1.config['lr']}") 
    print(f"Config LR của Pipeline 3: {pipeline_3.config['lr']}") 


if __name__ == "__main__":
    _run_smoke_test()