| | import math |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as thf |
| | import pytorch_lightning as pl |
| | from ldm.util import instantiate_from_config |
| | import einops |
| | import kornia |
| | import numpy as np |
| | import torchvision |
| | from contextlib import contextmanager |
| | from ldm.modules.ema import LitEma |
| |
|
| |
|
class FlAE(pl.LightningModule):
    """Steganographic autoencoder: hides a bit-string secret inside a cover
    image (encoder) and recovers it from the (possibly noised/cropped)
    stego image (decoder).

    Training curriculum (driven by bit accuracy in ``shared_step``):
      1. warmup on a single fixed batch (``fixed_input`` buffer is True);
      2. switch to the full dataset once bit acc > 0.9;
      3. activate the noise layer once bit acc > 0.95;
      4. activate the loss ramp once bit acc > 0.98.
    """

    def __init__(self,
                 cover_key,
                 secret_key,
                 secret_len,
                 resolution,
                 secret_encoder_config,
                 secret_decoder_config,
                 loss_config,
                 noise_config='__none__',
                 ckpt_path="__none__",
                 use_ema=False
                 ):
        """
        Args:
            cover_key: batch-dict key holding the cover images (B,H,W,C in [-1,1] — TODO confirm range against dataloader).
            secret_key: batch-dict key holding the secret bit tensors.
            secret_len: number of secret bits; propagated into both sub-configs.
            resolution: cover image resolution; propagated into the encoder config.
            secret_encoder_config / secret_decoder_config / loss_config: OmegaConf-style
                configs consumed by ``instantiate_from_config``.
            noise_config: optional distortion-pipeline config; '__none__' disables it.
            ckpt_path: optional checkpoint to restore from; '__none__' disables it.
            use_ema: keep EMA shadows of encoder/decoder weights.
        """
        super().__init__()
        self.cover_key = cover_key
        self.secret_key = secret_key
        # Propagate shared hyper-parameters into the sub-module configs.
        secret_encoder_config.params.secret_len = secret_len
        secret_decoder_config.params.secret_len = secret_len
        secret_encoder_config.params.resolution = resolution
        # The decoder always runs at 224x224 (matches self.crop below).
        secret_decoder_config.params.resolution = 224
        self.encoder = instantiate_from_config(secret_encoder_config)
        self.decoder = instantiate_from_config(secret_decoder_config)
        self.loss_layer = instantiate_from_config(loss_config)
        if noise_config != '__none__':
            print('Using noise')
            # self.noise only exists when a noise config is supplied; every
            # use site must therefore guard with hasattr(self, 'noise').
            self.noise = instantiate_from_config(noise_config)

        self.use_ema = use_ema
        if self.use_ema:
            print('Using EMA')
            self.encoder_ema = LitEma(self.encoder)
            self.decoder_ema = LitEma(self.decoder)
            print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")

        if ckpt_path != "__none__":
            self.init_from_ckpt(ckpt_path, ignore_keys=[])

        # Warmup state: a single cached batch is reused until bit accuracy
        # clears the 0.9 threshold (see shared_step). fixed_input is a buffer
        # so the warmup flag survives checkpointing.
        self.fixed_img = None
        self.fixed_secret = None
        self.register_buffer("fixed_input", torch.tensor(True))
        self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample")

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load a Lightning checkpoint's state_dict, skipping keys that start
        with any prefix in ``ignore_keys``. Loading is non-strict."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap encoder/decoder weights for their EMA shadows.

        No-op when ``use_ema`` is False; training weights are always restored
        on exit, even if the body raises.
        """
        if self.use_ema:
            self.encoder_ema.store(self.encoder.parameters())
            self.decoder_ema.store(self.decoder.parameters())
            self.encoder_ema.copy_to(self.encoder)
            self.decoder_ema.copy_to(self.decoder)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.encoder_ema.restore(self.encoder.parameters())
                self.decoder_ema.restore(self.decoder.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        # Update EMA shadows after every optimizer step.
        if self.use_ema:
            self.encoder_ema(self.encoder)
            self.decoder_ema(self.decoder)

    @torch.no_grad()
    def get_input(self, batch, bs=None):
        """Extract (cover image, secret) from a batch dict.

        Converts images from BHWC to BCHW. During warmup (``fixed_input``)
        the first batch seen is cached and returned on every call, trimmed
        to a common batch size with the incoming secrets.
        """
        image = batch[self.cover_key]
        secret = batch[self.secret_key]
        if bs is not None:
            image = image[:bs]
            secret = secret[:bs]
        else:
            bs = image.shape[0]

        image = einops.rearrange(image, "b h w c -> b c h w").contiguous()

        if self.fixed_input:
            if self.fixed_img is None:  # first call: cache the warmup batch
                print('[TRAINING] Warmup - using fixed input image for now!')
                self.fixed_img = image.detach().clone()[:bs]
                self.fixed_secret = secret.detach().clone()[:bs]
            image = self.fixed_img
            # The cached batch may be smaller than the incoming one.
            new_bs = min(secret.shape[0], image.shape[0])
            image, secret = image[:new_bs], secret[:new_bs]

        out = [image, secret]
        return out

    def forward(self, cover, secret):
        """Return (stego image, residual). The encoder either predicts the
        residual directly (return_residual=True) or the full stego image."""
        enc_out = self.encoder(cover, secret)
        if self.encoder.return_residual:
            return cover + enc_out, enc_out
        else:
            return enc_out, enc_out - cover

    def shared_step(self, batch):
        """One forward + loss pass shared by train/val; also advances the
        training curriculum based on the current bit accuracy."""
        x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if hasattr(self, "noise") and self.noise.is_activated():
            stego_noised = self.noise(stego, self.global_step, p=0.9)
        else:
            stego_noised = self.crop(stego)
        stego_noised = torch.clamp(stego_noised, -1, 1)
        spred = self.decoder(stego_noised)

        loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
        bit_acc = loss_dict["bit_acc"]

        bit_acc_ = bit_acc.item()

        # BUG FIX: guard with hasattr like the other call sites — without it,
        # training crashed with AttributeError at bit_acc > 0.98 whenever
        # noise_config was '__none__'.
        if (bit_acc_ > 0.98) and (not self.fixed_input) and hasattr(self, 'noise') and self.noise.is_activated():
            self.loss_layer.activate_ramp(self.global_step)

        if (bit_acc_ > 0.95) and (not self.fixed_input):
            if hasattr(self, 'noise') and (not self.noise.is_activated()):
                self.noise.activate(self.global_step)

        if (bit_acc_ > 0.9) and self.fixed_input:
            print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
            self.fixed_input = ~self.fixed_input  # flip the 0-dim bool buffer
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log validation metrics with both training and EMA weights."""
        _, loss_dict_no_ema = self.shared_step(batch)
        loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def log_images(self, batch, fixed_input=False, **kwargs):
        """Return a dict of tensors for image logging: input, stego, residual
        (min-max normalized to [-1,1]) and, when active, the noised stego."""
        log = dict()
        if fixed_input and self.fixed_img is not None:
            x, s = self.fixed_img, self.fixed_secret
        else:
            x, s = self.get_input(batch)
        stego, residual = self(x, s)
        if hasattr(self, 'noise') and self.noise.is_activated():
            img_noise = self.noise(stego, self.global_step, p=1.0)
            log['noised'] = img_noise
        log['input'] = x
        log['stego'] = stego
        log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
        return log

    def configure_optimizers(self):
        # self.learning_rate is set externally (Lightning convention) before
        # training starts — TODO confirm against the training script.
        lr = self.learning_rate
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=lr)
        return optimizer
| | |
| | |
| |
|
| |
|
| | class SecretEncoder(nn.Module): |
| | def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None: |
| | super().__init__() |
| | self.secret_len = secret_len |
| | self.return_residual = return_residual |
| | self.act_fn = lambda x: torch.tanh(x) if act == 'tanh' else thf.sigmoid(x) * 2.0 -1.0 |
| | self.secret_dense = nn.Linear(secret_len, 16*16*3) |
| | log_resolution = int(math.log(resolution, 2)) |
| | assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}." |
| | self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))) |
| | self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1) |
| | self.conv2 = nn.Conv2d(32, 32, 3, 2, 1) |
| | self.conv3 = nn.Conv2d(32, 64, 3, 2, 1) |
| | self.conv4 = nn.Conv2d(64, 128, 3, 2, 1) |
| | self.conv5 = nn.Conv2d(128, 256, 3, 2, 1) |
| | self.pad6 = nn.ZeroPad2d((0, 1, 0, 1)) |
| | self.up6 = nn.Conv2d(256, 128, 2, 1) |
| | self.upsample6 = nn.Upsample(scale_factor=(2, 2)) |
| | self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1) |
| | self.pad7 = nn.ZeroPad2d((0, 1, 0, 1)) |
| | self.up7 = nn.Conv2d(128, 64, 2, 1) |
| | self.upsample7 = nn.Upsample(scale_factor=(2, 2)) |
| | self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1) |
| | self.pad8 = nn.ZeroPad2d((0, 1, 0, 1)) |
| | self.up8 = nn.Conv2d(64, 32, 2, 1) |
| | self.upsample8 = nn.Upsample(scale_factor=(2, 2)) |
| | self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1) |
| | self.pad9 = nn.ZeroPad2d((0, 1, 0, 1)) |
| | self.up9 = nn.Conv2d(32, 32, 2, 1) |
| | self.upsample9 = nn.Upsample(scale_factor=(2, 2)) |
| | self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1) |
| | self.conv10 = nn.Conv2d(32, 32, 3, 1, 1) |
| | self.residual = nn.Conv2d(32, 3, 1) |
| | |
| | def forward(self, image, secret): |
| | fingerprint = thf.relu(self.secret_dense(secret)) |
| | fingerprint = fingerprint.view((-1, 3, 16, 16)) |
| | fingerprint_enlarged = self.secret_upsample(fingerprint) |
| | |
| | inputs = torch.cat([fingerprint_enlarged, image], dim=1) |
| | |
| | |
| | |
| | conv1 = thf.relu(self.conv1(inputs)) |
| | conv2 = thf.relu(self.conv2(conv1)) |
| | conv3 = thf.relu(self.conv3(conv2)) |
| | conv4 = thf.relu(self.conv4(conv3)) |
| | conv5 = thf.relu(self.conv5(conv4)) |
| | up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5)))) |
| | merge6 = torch.cat([conv4, up6], dim=1) |
| | conv6 = thf.relu(self.conv6(merge6)) |
| | up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6)))) |
| | merge7 = torch.cat([conv3, up7], dim=1) |
| | conv7 = thf.relu(self.conv7(merge7)) |
| | up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7)))) |
| | merge8 = torch.cat([conv2, up8], dim=1) |
| | conv8 = thf.relu(self.conv8(merge8)) |
| | up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8)))) |
| | merge9 = torch.cat([conv1, up9, inputs], dim=1) |
| | conv9 = thf.relu(self.conv9(merge9)) |
| | conv10 = thf.relu(self.conv10(conv9)) |
| | residual = self.residual(conv10) |
| | residual = self.act_fn(residual) |
| | return residual |
| |
|
| |
|
class SecretEncoder1(nn.Module):
    """Placeholder for an alternative secret encoder; no forward pass is
    implemented yet."""

    def __init__(self, resolution=256, secret_len=100) -> None:
        # BUG FIX: the stub never called nn.Module.__init__, so any instance
        # would raise AttributeError on .parameters(), .to(), or attribute
        # registration. Initialize the base class and keep the config around.
        super().__init__()
        self.resolution = resolution
        self.secret_len = secret_len
| |
|
class SecretDecoder(nn.Module):
    """Decoder that predicts the secret bits from a (possibly noised) stego
    image, using a pretrained ResNet backbone or a small custom CNN."""

    def __init__(self, arch='resnet18', resolution=224, secret_len=100):
        super().__init__()
        self.resolution = resolution
        self.arch = arch
        if arch == 'simple':
            self.decoder = SimpleCNN(resolution, secret_len)
        elif arch in ('resnet18', 'resnet50'):
            backbone = getattr(torchvision.models, arch)(pretrained=True, progress=False)
            # Swap the ImageNet classifier head for a secret-bit regressor.
            backbone.fc = nn.Linear(backbone.fc.in_features, secret_len)
            self.decoder = backbone
        else:
            raise ValueError('Unknown architecture')

    def forward(self, image):
        """Return secret logits of shape (B, secret_len) for ``image``."""
        # ResNet backbones expect self.resolution inputs; shrink larger
        # images (smaller ones are passed through unchanged).
        needs_resize = self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution
        if needs_resize:
            image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
        return self.decoder(image)
| |
|
| |
|
class SimpleCNN(nn.Module):
    """Small convolutional secret decoder: five stride-2 stages (overall
    downsample x32) followed by a two-layer MLP head.

    ``resolution`` must be divisible by 32 so the flattened feature size
    matches the first dense layer.
    """

    def __init__(self, resolution=224, secret_len=100):
        super().__init__()
        self.resolution = resolution
        self.IMAGE_CHANNELS = 3
        # Conv stack; layer order is fixed (it defines the state_dict keys).
        self.decoder = nn.Sequential(
            nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), 2, 1),
            nn.ReLU(),
        )
        # 128 channels at (resolution/32)^2 spatial positions.
        self.dense = nn.Sequential(
            nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
            nn.ReLU(),
            nn.Linear(512, secret_len),
        )

    def forward(self, image):
        """Map ``image`` (B, 3, H, W) to secret logits (B, secret_len)."""
        features = self.decoder(image)
        flat = features.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
        return self.dense(flat)