# sagan_model.py
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# -------------------------
# Self-Attention Module
# -------------------------
class Self_Attn(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # 1x1 convs project features into query/key/value spaces;
        # query/key are reduced to in_dim // 8 channels to save memory.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        # gamma starts at zero, so the block is initially an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        proj_q = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)
        proj_k = self.key_conv(x).view(B, -1, H * W)
        energy = torch.bmm(proj_q, proj_k)  # B x (HW) x (HW)
        attention = self.softmax(energy)
        proj_v = self.value_conv(x).view(B, -1, H * W)
        out = torch.bmm(proj_v, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)
        # Residual connection scaled by the learned gamma
        return self.gamma * out + x

# -------------------------
# Generator & Discriminator
# -------------------------
class Generator(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            # (B, z_dim, 1, 1) -> 4x4
            spectral_norm(nn.ConvTranspose2d(z_dim, base_channels * 8, 4, 1, 0)),
            nn.BatchNorm2d(base_channels * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            spectral_norm(nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(True),
            # insert self-attention at the 8x8 feature map
            Self_Attn(base_channels * 4),
            # 8x8 -> 16x16
            spectral_norm(nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            spectral_norm(nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(True),
            # 32x32 -> 64x64
            spectral_norm(nn.ConvTranspose2d(base_channels, img_channels, 4, 2, 1)),
            nn.Tanh(),
        )

    def forward(self, z):
        # Expect z shape: (B, z_dim, 1, 1)
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32
            spectral_norm(nn.Conv2d(img_channels, base_channels, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 32x32 -> 16x16
            spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # self-attention at the 16x16 feature map
            Self_Attn(base_channels * 2),
            # 16x16 -> 8x8
            spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 8x8 -> 4x4
            spectral_norm(nn.Conv2d(base_channels * 4, base_channels * 8, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 4x4 -> 1x1: a single unnormalized score per image
            spectral_norm(nn.Conv2d(base_channels * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.net(x).view(-1)

# -------------------------
# High-Level Wrapper
# -------------------------
class SAGANModel(nn.Module):
    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.gen = Generator(z_dim, img_channels, base_channels)
        self.dis = Discriminator(img_channels, base_channels)

    def forward(self, z):
        # Only the generator's forward pass is typically used during inference
        return self.gen(z)
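
# -------------------------
# Smoke Test (illustrative)
# -------------------------
# A minimal sketch of how the pieces fit together, assuming 64x64 RGB images
# and the default hyperparameters above. The batch size of 4 and the latent
# sampling below are arbitrary choices for illustration, not part of the model.
if __name__ == "__main__":
    model = SAGANModel(z_dim=128, img_channels=3, base_channels=64)
    z = torch.randn(4, 128, 1, 1)   # batch of 4 latent vectors, shape (B, z_dim, 1, 1)
    fake = model(z)                 # -> (4, 3, 64, 64), values in [-1, 1] from Tanh
    scores = model.dis(fake)        # -> (4,), one unnormalized score per image
    print(fake.shape, scores.shape)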