Muinez committed
Commit 7a67c9d · verified · 1 Parent(s): 72772a5

Upload train.py with huggingface_hub

Browse files
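As the commit message says, the file was pushed with the huggingface_hub client. A minimal sketch of such an upload (the repo id below is a placeholder, not taken from this commit):

from huggingface_hub import upload_file

# Push the local training script to the root of the target repo.
# "<user>/<repo>" is a placeholder; substitute the actual repository id.
upload_file(
    path_or_fileobj="train.py",
    path_in_repo="train.py",
    repo_id="<user>/<repo>",
    commit_message="Upload train.py with huggingface_hub",
)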
Files changed (1)
  1. train.py +125 -0
train.py ADDED
@@ -0,0 +1,125 @@
+ import os
+
+ # Point the Hugging Face cache at a larger drive *before* importing
+ # transformers, otherwise the setting is picked up too late.
+ os.environ['HF_HOME'] = '/home/muinez/hf_home'
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torchvision.transforms import v2
+ from tqdm import tqdm
+
+ from stae import StupidAE
+ from transformers import AutoModel
+
+ # Frozen SigLIP 2 vision tower used as the semantic teacher; the text tower
+ # is dropped to free memory.
+ siglip = AutoModel.from_pretrained("google/siglip2-base-patch32-256", trust_remote_code=True).bfloat16().cuda()
+ siglip.text_model = None
+ torch.cuda.empty_cache()
+
+ vae = StupidAE().cuda()
+ params = list(vae.parameters())
+
+ # Muon updates the matrix-shaped weights; gains and biases (ndim < 2) fall
+ # back to the auxiliary Adam group.
+ from muon import SingleDeviceMuonWithAuxAdam
+ hidden_weights = [p for p in params if p.ndim >= 2]
+ hidden_gains_biases = [p for p in params if p.ndim < 2]
+ param_groups = [
+     dict(params=hidden_weights, use_muon=True,
+          lr=1e-4, weight_decay=1e-4),
+     dict(params=hidden_gains_biases, use_muon=False,
+          lr=3e-4, betas=(0.9, 0.95), weight_decay=1e-4),
+ ]
+ optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
+
+ from snooc import SnooC
+ optimizer = SnooC(optimizer)
+
+ from torchvision.io import decode_image
+ import webdataset as wds
+
+ def decode_image_data(key, value):
+     # Decode raw jpg/webp bytes into an RGB uint8 tensor; skip anything unreadable.
+     if key.endswith((".jpg", ".jpeg", ".webp")):
+         try:
+             return decode_image(torch.tensor(list(value), dtype=torch.uint8), mode="RGB")
+         except Exception:
+             return None
+     return None
+
+ image_transforms = v2.Compose([
+     v2.ToDtype(torch.float32, scale=True),
+     v2.Resize((256, 256)),
+     v2.Normalize([0.5], [0.5]),  # map to [-1, 1]
+     # v2.RandomHorizontalFlip(0.5),
+     # v2.RandomVerticalFlip(0.5),
+ ])
+
+ def preprocess(sample):
+     # Normalize both .jpg and .webp samples to a single 'jpg' key.
+     image_key = 'jpg' if 'jpg' in sample else 'webp' if 'webp' in sample else None
+     if image_key:
+         sample[image_key] = image_transforms(sample[image_key])
+         sample['jpg'] = sample.pop(image_key)
+     return sample
+
+ batch_size = 96
+ num_workers = 16
+
+ urls = [
+     f"https://huggingface.co/datasets/Muinez/sankaku-webp-256shortest-edge/resolve/main/{i:04d}.tar"
+     for i in range(1000)
+ ]
+
+ dataset = wds.WebDataset(urls, handler=wds.warn_and_continue, shardshuffle=100000) \
+     .shuffle(2000) \
+     .decode(decode_image_data) \
+     .map(preprocess) \
+     .to_tuple("jpg")  # batching is left to the DataLoader (.batched(batch_size) unused)
+
+ from torch.utils.tensorboard import SummaryWriter
+ import datetime
+ logger = SummaryWriter(f'./logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}')
+
+ # Resume from the previous checkpoint.
+ vae.load_state_dict(torch.load('model_2.pt'))
+
+ step = 0
+ while True:
+     # Rebuild the loader each pass so training simply restarts the shards.
+     dataloader = DataLoader(
+         dataset,
+         num_workers=num_workers,
+         batch_size=batch_size,
+         prefetch_factor=16, persistent_workers=True,
+         drop_last=True
+     )
+     bar = tqdm(dataloader)
+     for data, in bar:  # .to_tuple("jpg") yields 1-tuples
+         image = data.cuda().bfloat16()
+
+         # Teacher targets: SigLIP patch embeddings, normalized by their std.
+         with torch.no_grad(), torch.amp.autocast('cuda', torch.bfloat16):
+             last_hidden_state = siglip.vision_model(image, output_hidden_states=True).last_hidden_state
+             std = last_hidden_state.std()
+             last_hidden_state = last_hidden_state / std
+
+         with torch.amp.autocast('cuda', torch.bfloat16):
+             latent = vae.encode(image)
+             decoded = vae.decode(latent)
+             # [B, C, H, W] -> [B, H*W, C] so it lines up with the SigLIP token sequence.
+             semantic = vae.semantic_decoder(latent) / std
+             semantic = semantic.flatten(2).transpose(1, 2)
+
+         pixel_loss = F.mse_loss(decoded.float(), image.float())
+         semantic_loss = F.mse_loss(semantic.float(), last_hidden_state.float())
+         loss = pixel_loss + semantic_loss
+
+         loss.backward()
+         grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
+         optimizer.step()
+         optimizer.zero_grad()
+
+         if step % 1000 == 0:
+             torch.save(vae.state_dict(), 'model_2.pt')
+
+         bar.set_description(f'Step: {step}, Loss: {loss.item()}, Grad norm: {grad_norm}, Std: {latent.std()}')
+
+         logger.add_scalar('Pixel loss', pixel_loss, step)
+         logger.add_scalar('Semantic loss', semantic_loss, step)
+         if step % 50 == 0:
+             for i in range(3):
+                 # Cast to float32 before logging: TensorBoard cannot convert bfloat16 tensors.
+                 logger.add_image(f'Decoded/{i}', decoded[i].float().cpu() * 0.5 + 0.5, step)
+                 logger.add_image(f'Real/{i}', image[i].float().cpu() * 0.5 + 0.5, step)
+         logger.flush()
+
+         step += 1
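The stae module that provides StupidAE is not part of this commit. For orientation only, below is a minimal stand-in that satisfies the interface the training loop relies on: encode, decode, and a semantic_decoder whose [B, C, H, W] output flattens to SigLIP's [B, tokens, hidden] token layout. All channel counts and layer shapes here are assumptions, not the author's architecture.

import torch
from torch import nn

class StupidAE(nn.Module):
    # Stand-in with the call surface train.py expects; not the real model.
    def __init__(self, latent_channels: int = 16, sem_dim: int = 768):
        super().__init__()
        enc_ch = [3, 32, 64, 128, 256, 256]
        # 256x256 input -> 8x8 latent grid via five stride-2 convs.
        self.encoder = nn.Sequential(
            *[nn.Sequential(nn.Conv2d(enc_ch[i], enc_ch[i + 1], 3, stride=2, padding=1), nn.SiLU())
              for i in range(5)],
            nn.Conv2d(enc_ch[-1], latent_channels, 1),
        )
        dec_ch = [latent_channels, 256, 256, 128, 64, 32]
        # 8x8 latent -> 256x256 reconstruction via five stride-2 transposed convs.
        self.decoder = nn.Sequential(
            *[nn.Sequential(nn.ConvTranspose2d(dec_ch[i], dec_ch[i + 1], 4, stride=2, padding=1), nn.SiLU())
              for i in range(5)],
            nn.Conv2d(dec_ch[-1], 3, 3, padding=1),
        )
        # 1x1 projection of the latent to the teacher's hidden width (assumed 768
        # for the base SigLIP model), so flatten(2).transpose(1, 2) gives [B, 64, 768].
        self.semantic_decoder = nn.Conv2d(latent_channels, sem_dim, 1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)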