File size: 5,460 Bytes
ffb9865 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import glob
import torch
import json
import os
from PIL import Image
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import torch.distributions as dist
def load_state_dict_safely(model, state_dict):
    """Copy only the compatible entries of *state_dict* into *model*.

    An entry is accepted when its key exists in the model's own state dict
    and the tensor shapes agree; everything else is reported back instead of
    raising, unlike ``model.load_state_dict(..., strict=True)``.

    Returns:
        (matched_keys, skipped_keys): accepted key names, and human-readable
        descriptions of the rejected entries.
    """
    current = model.state_dict()
    matched_keys, skipped_keys = [], []
    for name, value in state_dict.items():
        target = current.get(name)
        if target is None:
            skipped_keys.append(f"'{name}' (отсутствует в модели)")
        elif value.shape != target.shape:
            skipped_keys.append(f"'{name}' (форма {value.shape} != {target.shape})")
        else:
            current[name] = value
            matched_keys.append(name)
    model.load_state_dict(current)
    return matched_keys, skipped_keys
def generate_skewed_tensor(shape, loc=-0.3, scale=1.0, device='cpu'):
    """Sample a bfloat16 tensor from a logit-normal distribution.

    Draws from Normal(loc, scale) and squashes through a sigmoid, so every
    value lies in (0, 1); with the default loc=-0.3 the mass is skewed
    below 0.5.

    Args:
        shape: output tensor shape.
        loc: mean of the underlying Gaussian.
        scale: std of the underlying Gaussian.
        device: device to sample on.
    """
    mean = torch.full(shape, loc, device=device, dtype=torch.bfloat16)
    std = torch.full(shape, scale, device=device, dtype=torch.bfloat16)
    gaussian = dist.Normal(mean, std)
    logit_normal = dist.TransformedDistribution(
        gaussian, [dist.transforms.SigmoidTransform()]
    )
    return logit_normal.sample()
def sample_images(vae, image, t=0.5, num_inference_steps=50, cond=None):
    """Integrate the flow ODE from time `t` to 1 to (re)generate images.

    `vae` is called as vae(x, cond) and its output is converted to a velocity
    via (pred - x) / (1 - t_cur), i.e. the model output is treated as an
    estimate of the clean image on a straight noise->image path.

    Args:
        vae: callable mapping (current sample, cond) -> clean-image estimate.
        image: reference tensor; at t=0 only its shape/dtype/device matter.
        t: fraction of `image` mixed into the starting point (0 = pure noise).
           NOTE(review): the time grid always starts at 0 regardless of `t` —
           confirm that is intended for t > 0.
        num_inference_steps: number of points on the linear time grid.
        cond: conditioning forwarded to `vae` unchanged.

    Returns:
        The integrated sample at time 1, same shape as `image`.
    """
    # The progress bar is cosmetic; degrade gracefully if tqdm is absent.
    try:
        from tqdm import tqdm as _progress
    except ImportError:
        _progress = lambda it: it
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Fix: device/dtype were hard-coded to 'cuda'/bfloat16; follow the input
    # instead so CPU or fp32 tensors work too.
    timesteps = torch.linspace(0, 1, num_inference_steps,
                               device=image.device, dtype=image.dtype)
    # Start from a noise/image mixture at "time" t on the straight path.
    x = (1 - t) * torch.randn_like(image) + t * image
    for i in _progress(range(num_inference_steps - 1)):
        t_cur = timesteps[i].unsqueeze(0)
        dt = timesteps[i + 1] - t_cur
        pred = vae(x, cond)
        # Convert the x-prediction into a velocity along the straight path.
        velocity = (pred - x) / (1 - t_cur)
        x = x + velocity * dt
    return x
# Fix: HF_HOME must be exported BEFORE transformers/diffusers are imported —
# huggingface_hub resolves the cache directory once at import time, so the
# original ordering (assignment after the imports) silently had no effect.
os.environ['HF_HOME'] = '/home/muinez/hf_home'

from stae_pixel import StupidAE
from diffusers import AutoencoderKL
from transformers import AutoModel

# Frozen SigLIP2 vision encoder; the text tower is dropped to free memory.
# Not used in the loss below — kept for the commented-out conditioning
# experiment in the training loop.
siglip = AutoModel.from_pretrained("google/siglip2-base-patch32-256", trust_remote_code=True).bfloat16().cuda()
siglip.text_model = None
torch.cuda.empty_cache()

vae = StupidAE().cuda()
params = list(vae.parameters())

from muon import SingleDeviceMuonWithAuxAdam
# Muon recipe: Muon for matrix-shaped params (ndim >= 2), AdamW-style aux
# group for gains/biases (ndim < 2).
hidden_weights = [p for p in params if p.ndim >= 2]
hidden_gains_biases = [p for p in params if p.ndim < 2]
param_groups = [
    dict(params=hidden_weights, use_muon=True,
         lr=5e-4, weight_decay=0),
    dict(params=hidden_gains_biases, use_muon=False,
         lr=3e-4, betas=(0.9, 0.95), weight_decay=0),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

from snooc import SnooC
# Wrap the base optimizer with SnooC (project-local optimizer wrapper).
optimizer = SnooC(optimizer)
from torchvision.io import decode_image
import webdataset as wds
def decode_image_data(key, value):
    """webdataset decoder: raw jpg/jpeg/webp bytes -> RGB uint8 image tensor.

    Returns None for non-image keys and for undecodable payloads, which lets
    webdataset skip the field instead of crashing the pipeline.
    """
    if key.endswith((".jpg", ".jpeg", ".webp")):
        try:
            # Perf fix: frombuffer avoids the per-byte Python list that
            # torch.tensor(list(value)) builds; bytearray gives torch a
            # writable buffer so it does not warn about read-only memory.
            raw = torch.frombuffer(bytearray(value), dtype=torch.uint8)
            return decode_image(raw, mode="RGB")
        except Exception:
            # Corrupt images are expected in scraped data; drop them.
            return None
    return None
# Training-time preprocessing: scale to [0, 1], resize to 128x128, then
# normalize to [-1, 1]. Flip augmentations are intentionally disabled.
_transform_steps = [
    v2.ToDtype(torch.float32, scale=True),  # integer input -> float in [0, 1]
    v2.Resize((128, 128)),
    v2.Normalize([0.5], [0.5]),             # [0, 1] -> [-1, 1]
]
image_transforms = v2.Compose(_transform_steps)
def preprocess(sample):
    """Normalize a webdataset sample so its image always lives under 'jpg'.

    The image may arrive under 'jpg' or 'webp' ('jpg' wins if both exist);
    it is run through `image_transforms` and re-stored at 'jpg'. Samples
    without an image pass through untouched. Mutates and returns `sample`.
    """
    for key in ('jpg', 'webp'):
        if key in sample:
            sample['jpg'] = image_transforms(sample.pop(key))
            break
    return sample
batch_size = 512
num_workers = 32
# 1000 tar shards streamed directly over HTTP from the Hugging Face Hub.
urls = [
    f"https://huggingface.co/datasets/Muinez/sankaku-webp-256shortest-edge/resolve/main/{i:04d}.tar"
    for i in range(1000)
]
# Streaming pipeline: bad shards are warned about and skipped; samples are
# shuffled in a 2000-item buffer, decoded with the custom byte decoder above,
# key-normalized/transformed by preprocess, and yielded as 1-tuples of the
# image tensor (batching is left to the DataLoader below).
dataset = wds.WebDataset(urls, handler=wds.warn_and_continue, shardshuffle=100000) \
    .shuffle(2000) \
    .decode(decode_image_data) \
    .map(preprocess) \
    .to_tuple("jpg")#.batched(batch_size)
from torch.utils.tensorboard import SummaryWriter
import datetime
# One timestamped TensorBoard run directory per launch.
logger = SummaryWriter(f'./logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')
# Warm-start from the last checkpoint; incompatible keys are skipped silently.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — fine for a self-produced checkpoint, but confirm the file is
# trusted.
load_state_dict_safely(vae, torch.load('pixel_flow_ae.pt'))
step = 0
while(True):
    # Rebuild the loader each pass so the streamed webdataset starts over.
    dataloader = DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        prefetch_factor=16, persistent_workers=True,
        drop_last=True
    )
    bar = tqdm(dataloader)
    for data, in bar:  # .to_tuple("jpg") yields 1-tuples
        image = data.cuda().bfloat16()
        # with torch.no_grad(), torch.amp.autocast('cuda', torch.bfloat16):
        #     last_hidden_state = siglip.vision_model(image, output_hidden_states=True).last_hidden_state
        with torch.amp.autocast('cuda', torch.bfloat16):
            device = image.device
            cond = vae.encode(image)
            # One timestep per sample; logit-normal with loc=-0.3 skews the
            # draws below 0.5.
            t = generate_skewed_tensor((image.shape[0],1,1,1), device=device).to(torch.bfloat16)
            x0 = torch.randn_like(image)
            # Clamp keeps the shared 1/(1 - t) factor finite as t -> 1.
            t_clamped = (1 - t).clamp(0.05, 1)
            # Straight-path interpolation: t=0 is pure noise, t=1 the image.
            xt = (1 - t) * x0 + t * image
            pred = vae(xt, cond)
            # Both terms share the same divisor, so the MSE below reduces to
            # an x-prediction loss weighted by 1/clamp(1 - t)^2:
            # velocity - target = (image - pred) / t_clamped.
            velocity = (xt - pred) / t_clamped
            target = (xt - image) / t_clamped
        # Loss is computed in fp32 for numerical stability.
        loss = torch.nn.functional.mse_loss(velocity.float(), target.float())
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        # Checkpoint every 1000 steps (also saves immediately at step 0).
        if(step % 1000 == 0):
            torch.save(vae.state_dict(), 'pixel_flow_ae.pt')
        bar.set_description(f'Step: {step}, Loss: {loss.item()}, Grad norm: {grad_norm}')
        logger.add_scalar(f'Loss', loss, step)
        # Periodic qualitative check: regenerate 4 images from pure noise
        # (t=0.0) conditioned on their own encodings.
        if(step % 50 == 0):
            with torch.amp.autocast('cuda', torch.bfloat16):
                decoded = sample_images(vae, image[:4], t=0.0, cond=cond[:4])
            for i in range(4):
                # Undo the [-1, 1] normalization for display.
                logger.add_image(f'Decoded/{i}', decoded[i].cpu() * 0.5 + 0.5, step)
                logger.add_image(f'Real/{i}', image[i].cpu() * 0.5 + 0.5, step)
            torch.cuda.empty_cache()
        logger.flush()
        step += 1