Muinez commited on
Commit
ffb9865
·
verified ·
1 Parent(s): 2df3e13

Upload train_pixel.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pixel.py +187 -0
train_pixel.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import torch
3
+ import json
4
+ import os
5
+ from PIL import Image
6
+ from torchvision.transforms import v2
7
+ from torch.utils.data import DataLoader
8
+ import torch.nn.functional as F
9
+ from tqdm import tqdm
10
+ import torch.distributions as dist
11
+
12
def load_state_dict_safely(model, state_dict):
    """Copy every compatible entry of *state_dict* into *model*.

    An entry is copied only when the key exists in the model and the tensor
    shapes agree; everything else is skipped instead of raising, so partially
    incompatible checkpoints still load.

    Returns:
        (matched_keys, skipped_keys) — names that were loaded, and
        human-readable notes for names that were skipped.
    """
    current = model.state_dict()
    matched_keys, skipped_keys = [], []

    for name, incoming in state_dict.items():
        target = current.get(name)
        if target is None:
            skipped_keys.append(f"'{name}' (отсутствует в модели)")
        elif incoming.shape != target.shape:
            skipped_keys.append(f"'{name}' (форма {incoming.shape} != {target.shape})")
        else:
            current[name] = incoming
            matched_keys.append(name)

    # Load the merged dict: untouched entries keep the model's own weights.
    model.load_state_dict(current)
    return matched_keys, skipped_keys
34
+
35
def generate_skewed_tensor(shape, loc=-0.3, scale=1.0, device='cpu'):
    """Sample a bfloat16 tensor from a logit-normal distribution.

    Draws Normal(loc, scale) and pushes it through a sigmoid, yielding values
    in (0, 1); with loc=-0.3 the mass is skewed below 0.5.

    Args:
        shape: tuple of output dimensions.
        loc: mean of the underlying Gaussian.
        scale: stddev of the underlying Gaussian.
        device: device for the sample.
    """
    mean = torch.full(shape, loc, device=device, dtype=torch.bfloat16)
    std = torch.full(shape, scale, device=device, dtype=torch.bfloat16)

    gaussian = dist.Normal(mean, std)
    logit_normal = dist.TransformedDistribution(
        gaussian, [dist.transforms.SigmoidTransform()]
    )
    return logit_normal.sample()
46
+ from tqdm import tqdm
47
+
48
def sample_images(vae, image, t=0.5, num_inference_steps=50, cond=None):
    """Euler-integrate the learned flow from a partially noised image to a sample.

    Starts from x = (1-t)*noise + t*image (t=0 is pure noise, t=1 the image)
    and steps along timesteps in [0, 1]; at each step the model's prediction
    is converted to a velocity (pred - x) / (1 - t_cur).

    Args:
        vae: callable (x, cond) -> prediction of the clean image.
        image: reference tensor; its device and dtype drive the integration.
        t: initial noise/image mixing coefficient.
        num_inference_steps: number of timesteps (num_inference_steps-1 updates).
        cond: optional conditioning passed through to the model.

    Returns:
        The integrated sample, same shape as *image*.
    """
    # Fix: derive device/dtype from the input instead of hard-coding 'cuda'/
    # bfloat16, so sampling also works on CPU or with other dtypes.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # free cached blocks before the sampling loop

    device = image.device
    timesteps = torch.linspace(0, 1, num_inference_steps,
                               device=device, dtype=image.dtype)

    x = (1 - t) * torch.randn_like(image) + t * image

    steps = range(num_inference_steps - 1)
    try:
        # Progress bar is best-effort; fall back to a plain range without tqdm.
        from tqdm import tqdm
        steps = tqdm(steps)
    except ImportError:
        pass

    for i in steps:
        t_cur = timesteps[i].unsqueeze(0)
        dt = timesteps[i + 1] - t_cur  # already on the right device; no .to() needed

        pred = vae(x, cond)
        flow = (pred - x) / (1 - t_cur)

        x = x + flow * dt

    return x
64
+
65
# --- Model & optimizer setup (runs at import time; requires CUDA) ---
from stae_pixel import StupidAE
from diffusers import AutoencoderKL
from transformers import AutoModel
# NOTE(review): HF_HOME is set after the transformers import — it only affects
# downloads performed from this point on; confirm that is intended.
os.environ['HF_HOME'] = '/home/muinez/hf_home'
# SigLIP2 vision encoder in bf16 on GPU; its usage in the training loop below
# is currently commented out.
siglip = AutoModel.from_pretrained("google/siglip2-base-patch32-256", trust_remote_code=True).bfloat16().cuda()
siglip.text_model = None  # drop the text tower to save memory
torch.cuda.empty_cache()
vae = StupidAE().cuda()  # the pixel-space flow model being trained

params = list(vae.parameters())

# Muon handles matrix-shaped (ndim >= 2) weights; an auxiliary Adam handles
# 1-D gains/biases, following the muon optimizer's recommended grouping.
from muon import SingleDeviceMuonWithAuxAdam
hidden_weights = [p for p in params if p.ndim >= 2]
hidden_gains_biases = [p for p in params if p.ndim < 2]
param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=5e-4, weight_decay=0),
    dict(params=hidden_gains_biases, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
from snooc import SnooC
# Wraps the base optimizer; SnooC's exact semantics live in the snooc
# package — TODO confirm what it adds (scheduling/EMA/etc.).
optimizer = SnooC(optimizer)
88
+
89
+ from torchvision.io import decode_image
90
+ import webdataset as wds
91
def decode_image_data(key, value):
    """webdataset decoder callback: decode .jpg/.jpeg/.webp bytes to an RGB tensor.

    Returns None for any other key, or when decoding fails, so corrupt or
    irrelevant fields are skipped instead of crashing the loader.
    """
    if not key.endswith((".jpg", ".jpeg", ".webp")):
        return None
    try:
        raw = torch.tensor(list(value), dtype=torch.uint8)
        return decode_image(raw, mode="RGB")
    except Exception:
        # Best-effort: undecodable payloads are dropped.
        return None
98
+
99
# Image preprocessing: uint8 -> float32 in [-1, 1], resized to 128x128.
image_transforms = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Resize((128, 128)),
    v2.Normalize([0.5], [0.5]),             # [0, 1] -> [-1, 1]
    # Flip augmentations currently disabled:
    #v2.RandomHorizontalFlip(0.5),
    #transforms.RandomVerticalFlip(0.5),
])
106
+
107
def preprocess(sample):
    """Normalize a webdataset sample: transform its image and store it as 'jpg'.

    Prefers a 'jpg' field over 'webp'; samples without either pass through
    untouched. The sample dict is modified in place and returned.
    """
    for candidate in ('jpg', 'webp'):
        if candidate in sample:
            sample['jpg'] = image_transforms(sample.pop(candidate))
            break
    return sample
114
batch_size = 512
num_workers = 32

# 1000 tar shards streamed directly from the Hugging Face Hub.
urls = [
    f"https://huggingface.co/datasets/Muinez/sankaku-webp-256shortest-edge/resolve/main/{i:04d}.tar"
    for i in range(1000)
]

# NOTE(review): shardshuffle=100000 exceeds the 1000 available shards —
# presumably just "shuffle all shards"; confirm against webdataset semantics.
# Batching is done by the DataLoader below, hence .batched() stays commented.
dataset = wds.WebDataset(urls, handler=wds.warn_and_continue, shardshuffle=100000) \
    .shuffle(2000) \
    .decode(decode_image_data) \
    .map(preprocess) \
    .to_tuple("jpg")#.batched(batch_size)
127
+
128
from torch.utils.tensorboard import SummaryWriter
import datetime
# One TensorBoard run per launch, keyed by timestamp.
logger = SummaryWriter(f'./logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')
# Resume from the checkpoint, tolerating missing/mismatched keys.
load_state_dict_safely(vae, torch.load('pixel_flow_ae.pt'))

step = 0
# Outer loop restarts the streaming dataloader whenever it is exhausted.
while(True):
    dataloader = DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        prefetch_factor=16, persistent_workers=True,
        drop_last=True
    )
    bar = tqdm(dataloader)
    for data, in bar:  # .to_tuple("jpg") yields 1-tuples, hence the unpack
        image = data.cuda().bfloat16()

        # SigLIP conditioning currently disabled:
        # with torch.no_grad(), torch.amp.autocast('cuda', torch.bfloat16):
        #     last_hidden_state = siglip.vision_model(image, output_hidden_states=True).last_hidden_state


        with torch.amp.autocast('cuda', torch.bfloat16):
            device = image.device
            # Conditioning comes from the model's own encoder (autoencoder-style).
            cond = vae.encode(image)

            # Per-sample timestep skewed toward small t (logit-normal, loc=-0.3).
            t = generate_skewed_tensor((image.shape[0],1,1,1), device=device).to(torch.bfloat16)

            x0 = torch.randn_like(image)
            # Clamp the denominator so the loss doesn't blow up as t -> 1.
            t_clamped = (1 - t).clamp(0.05, 1)

            # Linear interpolation between noise (t=0) and data (t=1).
            xt = (1 - t) * x0 + t * image
            pred = vae(xt, cond)
            # Both sides share the same 1/t_clamped factor, so this is an MSE
            # between pred and image reweighted by 1/t_clamped^2 — matching the
            # (pred - x) / (1 - t) velocity used in sample_images.
            velocity = (xt - pred) / t_clamped
            target = (xt - image) / t_clamped

            loss = torch.nn.functional.mse_loss(velocity.float(), target.float())

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        # Periodic checkpoint (also written at step 0, right after resume).
        if(step % 1000 == 0):
            torch.save(vae.state_dict(), 'pixel_flow_ae.pt')

        bar.set_description(f'Step: {step}, Loss: {loss.item()}, Grad norm: {grad_norm}')

        logger.add_scalar(f'Loss', loss, step)
        # Every 50 steps: sample 4 images from pure noise (t=0.0) and log them.
        if(step % 50 == 0):
            with torch.amp.autocast('cuda', torch.bfloat16):
                decoded = sample_images(vae, image[:4], t=0.0, cond=cond[:4])

            for i in range(4):
                # Undo the [-1, 1] normalization for display.
                logger.add_image(f'Decoded/{i}', decoded[i].cpu() * 0.5 + 0.5, step)
                logger.add_image(f'Real/{i}', image[i].cpu() * 0.5 + 0.5, step)
            torch.cuda.empty_cache()

        logger.flush()

        step += 1