# Source: SAE/attacks/Gaker/Generator/Generator.py
# (Hugging Face upload metadata: user Ttius, "Upload 192 files", commit 998bb30, verified)
import math
from telnetlib import PRAGMA_HEARTBEAT
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from einops import rearrange
class TargetEmbedding(nn.Module):
    """Embed a target feature vector into the conditioning dimension.

    A two-layer MLP (Linear -> GELU -> Linear) mapping
    ``feature_channel_num`` -> ``dim`` -> ``dim``.
    """

    def __init__(self, dim, feature_channel_num):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(feature_channel_num, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, mix):
        """Return the embedded target features for `mix`."""
        return self.linear(mix)
class DownSample(nn.Module):
    """Halve spatial resolution: two parallel strided convs (3x3 and 5x5), summed."""

    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1, padding_mode='reflect'),
        )
        self.c2 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2, padding_mode='reflect'),
        )

    def forward(self, x, target_emb):
        # target_emb is accepted for a uniform layer interface but unused here
        return self.c1(x) + self.c2(x)
class UpSample(nn.Module):
    """Double spatial resolution via GELU -> GroupNorm -> 5x5 transposed conv.

    FIX: removed the dead `_, _, H, W = x.shape` unpacking in forward —
    the values were never used.
    """

    def __init__(self, in_ch):
        super().__init__()
        # output_padding=1 makes the output exactly 2x the input spatial size
        self.t = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, in_ch),
            nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1),
        )

    def forward(self, x, target_emb):
        # target_emb is accepted for a uniform layer interface but unused here
        return self.t(x)
class AttnBlock(nn.Module):
    """Linear (efficient-style) self-attention over spatial positions.

    Keys and values come from a single 1x1 conv; the "query" map is a softmax
    over the *pre-softmax* key features — there is no separate query
    projection in this design. NOTE: the residual is added to the
    group-normalized input, not the raw input.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.norm = nn.GroupNorm(16, in_ch)
        self.to_kv = nn.Conv2d(in_ch, in_ch * 2, 1)
        self.out = nn.Sequential(
            nn.GroupNorm(16, in_ch),
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.norm(x)
        scale = int(C) ** (-0.5)
        # flatten spatial dims: (B, C, H, W) -> (B, C, H*W)
        k, v = (t.flatten(2) for t in self.to_kv(x).chunk(2, dim=1))
        # q derives from the raw key features (softmax over positions),
        # then k itself is softmaxed over channels
        q = F.softmax(k, dim=-2) * scale
        k = F.softmax(k, dim=-1)
        # context: (B, C, C) key-value summary
        context = torch.einsum('b d n, b e n -> b d e', k, v)
        assert list(context.shape) == [B, C, C]
        attended = torch.einsum('b d e, b d n -> b e n', context, q)
        assert list(attended.shape) == [B, C, H * W]
        attended = attended.reshape(B, C, H, W)
        return x + self.out(attended)
class ResBlock(nn.Module):
    """Residual block with target-embedding conditioning and optional attention.

    Args:
        in_ch: input channel count (must be divisible by 16 for GroupNorm).
        out_ch: output channel count.
        tdim: width of the target embedding vector passed to forward().
        attn: if True, append an AttnBlock after the residual sum.
    """

    def __init__(self, in_ch, out_ch, tdim, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, in_ch),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1, padding_mode='reflect'),
        )
        # projects the target embedding to a per-channel additive bias
        self.target_proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, out_ch),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, padding_mode='reflect'),
        )
        # 1x1 conv shortcut only when the channel count changes
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()

    def forward(self, x, target):
        h = self.block1(x)
        # inject the conditioning signal as a per-channel bias
        h = h + self.target_proj(target)[:, :, None, None]
        h = self.block2(h)
        # FIX: the original constructed a fresh nn.LayerNorm (allocating
        # throwaway learnable parameters) and moved it to the device on EVERY
        # forward pass. F.layer_norm with no weight/bias is numerically
        # identical (nn.LayerNorm initializes weight=1, bias=0; same eps=1e-5)
        # and allocates nothing.
        h = F.layer_norm(h, h.shape[1:])
        h = h + self.shortcut(x)
        return self.attn(h)
class GCT(nn.Module):
    """Gated Channel Transformation: re-weight channels by a learned global gate.

    Each channel is scaled by ``1 + tanh(embedding * norm + beta)`` where the
    embedding is a global L2 (or L1) pooling of that channel.

    Args:
        num_channels: number of input channels.
        tdim: unused; kept for call-site compatibility.
        epsilon: numerical-stability constant.
        mode: 'l2' or 'l1' — which norm builds the channel embedding.
        after_relu: in 'l1' mode, skip abs() when inputs are known non-negative.

    Raises:
        ValueError: if `mode` is neither 'l2' nor 'l1' (the original code
        fell through to a NameError on `embedding` instead).
    """

    def __init__(self, num_channels, tdim, epsilon=1e-5, mode='l2', after_relu=False):
        super().__init__()
        # alpha scales the embedding; gamma/beta shape the gate.
        # With gamma=0 and beta=0 (the initial values) the gate is exactly 1,
        # so the module starts as an identity.
        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3), keepdim=True) +
                         self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / \
                (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        elif self.mode == 'l1':
            _x = x if self.after_relu else torch.abs(x)
            embedding = _x.sum((2, 3), keepdim=True) * self.alpha
            norm = self.gamma / \
                (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
        else:
            # FIX: fail loudly on an unsupported mode instead of NameError
            raise ValueError(f"Unknown GCT mode: {self.mode!r}")
        gate = 1. + torch.tanh(embedding * norm + self.beta)
        return x * gate
class Generator(nn.Module):
    """U-Net-style conditional generator; outputs an image in [0, 1].

    A target feature vector `mix` is embedded once and injected into every
    ResBlock as a per-channel bias.

    Args:
        num_target: unused in this constructor — presumably the number of
            target classes; TODO confirm against callers.
        feature_channel_num: width of the conditioning feature vector `mix`.
        ch: base channel width; the conditioning embedding width is ch * 4.
        ch_mult: channel multiplier per resolution level (e.g. [1, 2, 4]).
        num_res_blocks: number of ResBlocks per level on the down path.
        inception: when True, crop `h` in the up path if a skip connection's
            spatial size differs (odd input sizes).
    """
    def __init__(self, num_target, feature_channel_num, ch, ch_mult, num_res_blocks, inception=False):
        super().__init__()
        tdim = ch * 4  # conditioning-embedding width
        self.target_embedding = TargetEmbedding(tdim, feature_channel_num)
        # stem: 5x5 conv from RGB to the base width
        self.head = nn.Sequential(
            nn.Conv2d(3, ch, kernel_size=5, padding=2, padding_mode='reflect'),
            nn.GELU(),
            nn.GroupNorm(16, ch)
        )
        # negative padding crops: (left, right, top, bottom) = (0, -1, -1, 0)
        # removes the last column and the first row
        self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
        self.inception = inception
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel count of each skip, pushed on the way down
        now_ch = ch
        feature_channel = feature_channel_num  # NOTE(review): unused local
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim))
                now_ch = out_ch
                chs.append(now_ch)
            # downsample between levels, but not after the deepest one
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, attn=True),
            ResBlock(now_ch, now_ch, tdim, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                # chs.pop() + now_ch: each up ResBlock consumes the
                # concatenated skip connection
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0  # every pushed skip channel count must be consumed
        self.gct = GCT(num_channels=now_ch, tdim=tdim)
        self.tail = nn.Sequential(
            nn.GroupNorm(16, now_ch),
            nn.GELU(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1, padding_mode='reflect'),
        )

    def weight_init(self):
        """Re-initialize weights: Xavier-uniform convs (bias 0.1), unit GroupNorm, N(0, 0.01) linears."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.1)
            elif isinstance(m, torch.nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, mix):
        """Generate an image from x (B, 3, H, W) conditioned on target features `mix`.

        Returns a tensor with values in [0, 1].
        """
        targetemb = self.target_embedding(mix)
        h = self.head(x)
        hs = [h]  # skip-connection stack
        for layer in self.downblocks:
            h = layer(h, targetemb)
            hs.append(h)
        for layer in self.middleblocks:
            h = layer(h, targetemb)
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                hs_pop = hs.pop()
                # odd input sizes can leave h one pixel larger than the skip;
                # crop it (inception mode only) before concatenating
                if hs_pop.size(2) != h.size(2) and self.inception:
                    h = self.crop(h)
                h = torch.cat([h, hs_pop], dim=1)
            h = layer(h, targetemb)
        h = self.gct(h)
        h = self.tail(h)
        assert len(hs) == 0  # every skip must have been consumed
        return (torch.tanh(h) + 1) / 2  # map tanh range [-1, 1] to [0, 1]