# File size: 5,460 Bytes
# ffb9865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import glob
import torch
import json
import os
from PIL import Image
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import torch.distributions as dist

def load_state_dict_safely(model, state_dict):
	"""Load only the compatible entries of ``state_dict`` into ``model``.

	An entry is skipped when its key is not in ``model.state_dict()`` or when
	its tensor shape differs from the model's tensor under that key. Every other
	entry replaces the model's current value, and the merged dict is applied
	with ``model.load_state_dict``.

	Returns:
		(matched_keys, skipped_keys): the list of keys that were loaded, and a
		list with one human-readable reason string per skipped key.
	"""
	merged = model.state_dict()
	loaded, rejected = [], []

	for name, value in state_dict.items():
		if name not in merged:
			# The checkpoint has a key this model does not define.
			rejected.append(f"'{name}' (отсутствует в модели)")
		elif value.shape != merged[name].shape:
			# Same key, incompatible tensor shape.
			rejected.append(f"'{name}' (форма {value.shape} != {merged[name].shape})")
		else:
			merged[name] = value
			loaded.append(name)

	model.load_state_dict(merged)
	return loaded, rejected

def generate_skewed_tensor(shape, loc=-0.3, scale=1.0, device='cpu'):
	"""Sample a bfloat16 tensor of ``shape`` from a logit-normal distribution.

	Each element is ``sigmoid(z)`` with ``z ~ Normal(loc, scale)``, so values lie
	in (0, 1). With the default negative ``loc`` the median is sigmoid(-0.3) < 0.5,
	i.e. samples are skewed towards the lower half of the interval.
	"""
	opts = dict(device=device, dtype=torch.bfloat16)
	mean = torch.full(shape, loc, **opts)
	std = torch.full(shape, scale, **opts)
	normal = dist.Normal(mean, std)
	# Push the Gaussian through a sigmoid (torch's transform; it may clip to a
	# dtype-safe range near 0 and 1 — see torch.distributions docs).
	squashed = dist.TransformedDistribution(normal, [dist.transforms.SigmoidTransform()])
	return squashed.sample()
from tqdm import tqdm

@torch.no_grad()
def sample_images(vae, image, t=0.5, num_inference_steps=50, cond=None):
	"""Generate samples by Euler-integrating the flow model from time ``t`` to 1.

	The start state is ``(1 - t) * noise + t * image``; ``t = 0`` starts from pure
	Gaussian noise. At every step the model output ``vae(x, cond)`` is turned into
	a velocity ``(pred - x) / (1 - t_cur)`` and integrated with step ``dt``.

	Runs under ``torch.no_grad()``: this is inference only, and otherwise autograd
	would keep the activations of every step alive.

	Args:
		vae: callable ``vae(x, cond)``; assumed to return a tensor shaped like
			``x`` (inferred from the arithmetic below — confirm).
		image: reference batch. Its shape/dtype/device define the noise, and its
			values contribute to the start state when ``t > 0``.
		t: start time in [0, 1]. The integration grid starts here, so the time
			values match the noise level of the start state.
		num_inference_steps: number of grid points on [t, 1]. ``num_inference_steps - 1``
			model evaluations are performed.
		cond: conditioning passed unchanged to ``vae`` at every step.

	Returns:
		The integrated tensor, same shape as ``image``.
	"""
	torch.cuda.empty_cache()
	x = (1 - t) * torch.randn_like(image) + t * image
	if t >= 1:
		# Already at the end of the trajectory; a step would divide by 1 - 1 = 0.
		return x

	# Plain Python floats rather than a bfloat16 CUDA tensor: no device
	# assumption, and no bfloat16 rounding that could collapse a late grid point
	# onto 1.0 (0/0 -> NaN) for large step counts. Scalar floats also keep
	# x in the input's dtype.
	timesteps = torch.linspace(t, 1, num_inference_steps, dtype=torch.float64).tolist()

	for i in tqdm(range(num_inference_steps - 1)):
		t_cur = timesteps[i]
		dt = timesteps[i + 1] - t_cur

		pred = vae(x, cond)
		flow = (pred - x) / (1 - t_cur)

		x = x + flow * dt

	return x

# HF_HOME must be set before anything imports huggingface_hub (pulled in by
# diffusers/transformers below). The cache location is resolved from the
# environment at import time, so setting it afterwards has no effect.
os.environ['HF_HOME'] = '/home/muinez/hf_home'

from stae_pixel import StupidAE
from diffusers import AutoencoderKL
from transformers import AutoModel

# NOTE(review): siglip is only referenced in commented-out code in the training
# loop. Loading it costs a download and GPU memory; consider removing it.
siglip = AutoModel.from_pretrained("google/siglip2-base-patch32-256", trust_remote_code=True).bfloat16().cuda()
siglip.text_model = None  # text tower is not used here; release it
torch.cuda.empty_cache()
vae = StupidAE().cuda()

params = list(vae.parameters())

from muon import SingleDeviceMuonWithAuxAdam
# Matrices/kernels (ndim >= 2) go to the Muon update; 1-D gains and biases go
# to the auxiliary Adam update (use_muon=False), each with its own lr.
hidden_weights = [p for p in params if p.ndim >= 2]
hidden_gains_biases = [p for p in params if p.ndim < 2]
param_groups = [
	dict(params=hidden_weights, use_muon=True,
		 lr=5e-4, weight_decay=0),
	dict(params=hidden_gains_biases, use_muon=False,
		 lr=3e-4, betas=(0.9, 0.95), weight_decay=0),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
from snooc import SnooC
# Wrap the base optimizer; see the snooc package for the wrapper's semantics.
optimizer = SnooC(optimizer)

from torchvision.io import decode_image
import webdataset as wds
def decode_image_data(key, value):
	"""WebDataset decoder handler for JPEG/WebP fields.

	Args:
		key: field key; only keys ending in ``.jpg``, ``.jpeg`` or ``.webp`` are handled.
		value: the encoded image bytes.

	Returns:
		The image decoded by torchvision's ``decode_image`` in ``"RGB"`` mode, or
		``None`` for keys this handler does not own and for payloads that fail to
		decode. NOTE(review): presumably WebDataset then leaves the field as raw
		bytes for other handlers — confirm against the installed version.
	"""
	if not key.endswith((".jpg", ".jpeg", ".webp")):
		return None
	try:
		# frombuffer wraps the bytes without building a per-byte Python list.
		# bytearray gives it a writable buffer; read-only bytes trigger a warning.
		encoded = torch.frombuffer(bytearray(value), dtype=torch.uint8)
		return decode_image(encoded, mode="RGB")
	except Exception:
		# Best-effort: an undecodable or empty payload is reported as "not handled".
		return None

# Per-image pipeline applied by ``preprocess``:
#   1. ToDtype(float32, scale=True): rescale integer pixels to float (e.g. uint8 -> [0, 1]).
#   2. Resize((128, 128)): fixed output size; aspect ratio is not preserved.
#   3. Normalize(0.5, 0.5): (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
image_transforms = v2.Compose(
	[
		v2.ToDtype(torch.float32, scale=True),
		v2.Resize((128, 128)),
		v2.Normalize([0.5], [0.5]),
	]
)

def preprocess(sample):
	"""Transform a decoded WebDataset sample and store its image under ``'jpg'``.

	The image is read from ``'jpg'`` if present, otherwise from ``'webp'``. It is
	run through ``image_transforms`` and moved to the ``'jpg'`` key, so the
	downstream ``to_tuple("jpg")`` works for both formats.

	Raises:
		ValueError: if the sample has no image field, or the field is not a
			decoded tensor (e.g. raw bytes left behind by a failed decode).
			Raising here gives a clear error at the source instead of a confusing
			failure later in ``to_tuple`` or batch collation.
	"""
	image_key = 'jpg' if 'jpg' in sample else 'webp' if 'webp' in sample else None
	if image_key is None:
		raise ValueError(f"sample has no image field: keys={sorted(sample)}")
	if not isinstance(sample[image_key], torch.Tensor):
		raise ValueError(f"image field {image_key!r} was not decoded to a tensor")

	sample['jpg'] = image_transforms(sample.pop(image_key))
	return sample
batch_size = 512
num_workers = 32

# 1000 WebDataset shards: 0000.tar ... 0999.tar.
urls = [
	f"https://huggingface.co/datasets/Muinez/sankaku-webp-256shortest-edge/resolve/main/{i:04d}.tar"
	for i in range(1000)
]

# Streaming pipeline: shuffle shards and samples, decode images, apply
# ``preprocess``, and emit 1-tuples holding the image tensor (the DataLoader
# does the batching). Every per-sample stage uses wds.warn_and_continue, so a
# single corrupt or incomplete sample is logged and skipped. With the default
# re-raise handler it would kill the data-loader worker and the training run.
dataset = (
	wds.WebDataset(urls, handler=wds.warn_and_continue, shardshuffle=100000)
	.shuffle(2000)
	.decode(decode_image_data, handler=wds.warn_and_continue)
	.map(preprocess, handler=wds.warn_and_continue)
	.to_tuple("jpg", handler=wds.warn_and_continue)
)

from torch.utils.tensorboard import SummaryWriter
import datetime
logger = SummaryWriter(f'./logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')

CHECKPOINT_PATH = 'pixel_flow_ae.pt'


def _save_checkpoint(state_dict, path=CHECKPOINT_PATH):
	"""Save ``state_dict`` to ``path`` via a temp file and an atomic rename.

	If the process dies mid-write, the previous checkpoint stays intact. Writing
	in place could leave a truncated file that the next run would fail to load.
	"""
	tmp_path = path + '.tmp'
	torch.save(state_dict, tmp_path)
	os.replace(tmp_path, path)


# Resume from the checkpoint when it exists (a first run has none). Report the
# entries load_state_dict_safely skipped instead of silently discarding them.
if os.path.exists(CHECKPOINT_PATH):
	_matched, _skipped = load_state_dict_safely(vae, torch.load(CHECKPOINT_PATH, map_location='cpu'))
	print(f'Loaded {len(_matched)} tensors from {CHECKPOINT_PATH}')
	for _reason in _skipped:
		print(f'Skipped {_reason}')
else:
	print(f'No checkpoint at {CHECKPOINT_PATH}, training from scratch')

# Built once, outside the epoch loop. With persistent_workers=True the worker
# processes are then reused across passes; re-creating the DataLoader each pass
# would respawn all workers and refill the prefetch queues every time.
dataloader = DataLoader(
	dataset,
	num_workers=num_workers,
	batch_size=batch_size,
	prefetch_factor=16, persistent_workers=True,
	drop_last=True
)

step = 0
while True:
	bar = tqdm(dataloader)
	for data, in bar:
		image = data.cuda().bfloat16()

		with torch.amp.autocast('cuda', torch.bfloat16):
			device = image.device
			cond = vae.encode(image)

			# Per-sample time in (0, 1), logit-normal and skewed below 0.5 by
			# generate_skewed_tensor's default loc=-0.3.
			t = generate_skewed_tensor((image.shape[0], 1, 1, 1), device=device).to(torch.bfloat16)

			x0 = torch.randn_like(image)
			# Keep 1 - t away from 0 so the divisions below stay bounded as t -> 1.
			t_clamped = (1 - t).clamp(0.05, 1)

			# Linear path from noise (t = 0) to data (t = 1).
			xt = (1 - t) * x0 + t * image
			pred = vae(xt, cond)
			# velocity - target == (image - pred) / t_clamped, so the MSE below is an
			# image-prediction loss weighted by 1 / t_clamped**2.
			velocity = (xt - pred) / t_clamped
			target = (xt - image) / t_clamped

		loss = torch.nn.functional.mse_loss(velocity.float(), target.float())

		loss.backward()
		grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
		optimizer.step()
		optimizer.zero_grad()
		if step % 1000 == 0:
			_save_checkpoint(vae.state_dict())

		loss_value = loss.item()
		bar.set_description(f'Step: {step}, Loss: {loss_value}, Grad norm: {grad_norm}')

		logger.add_scalar('Loss', loss_value, step)
		if step % 50 == 0:
			# no_grad: sampling evaluates the model many times, and autograd would
			# otherwise keep every evaluation's activations alive.
			with torch.no_grad(), torch.amp.autocast('cuda', torch.bfloat16):
				decoded = sample_images(vae, image[:4], t=0.0, cond=cond[:4])

			for i in range(4):
				# Log in float32; "* 0.5 + 0.5" undoes the [-1, 1] input normalization.
				logger.add_image(f'Decoded/{i}', decoded[i].float().cpu() * 0.5 + 0.5, step)
				logger.add_image(f'Real/{i}', image[i].float().cpu() * 0.5 + 0.5, step)
			del decoded
			torch.cuda.empty_cache()

		logger.flush()

		step += 1