Spaces:
Sleeping
Sleeping
# sagan_model.py
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# -------------------------
# Self-Attention Module
# -------------------------
class Self_Attn(nn.Module):
    """Self-attention layer (SAGAN-style).

    Attends over all spatial positions of a feature map and adds the
    attended features back onto the input, scaled by a learnable
    ``gamma`` that is zero-initialised — so the layer starts out as an
    exact identity and learns how much attention to mix in.
    """

    def __init__(self, in_dim):
        super().__init__()
        # Queries/keys are projected down to in_dim // 8 channels to
        # keep the (N x N) affinity computation cheap.
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        # Learnable residual weight, starts at zero.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; same shape as ``x``."""
        batch, channels, width, height = x.size()
        n = width * height
        queries = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)
        keys = self.key_conv(x).view(batch, -1, n)
        # (B, N, N): softmax-normalised affinity between every pair of
        # spatial positions.
        attn = self.softmax(torch.bmm(queries, keys))
        values = self.value_conv(x).view(batch, -1, n)
        attended = torch.bmm(values, attn.permute(0, 2, 1))
        attended = attended.view(batch, channels, width, height)
        return self.gamma * attended + x
# -------------------------
# Generator & Discriminator
# -------------------------
class Generator(nn.Module):
    """SAGAN generator: maps a latent vector to a 64x64 image in [-1, 1].

    Spatial sizes along the upsampling path: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        ch = base_channels
        self.net = nn.Sequential(
            # (B, z_dim, 1, 1) -> (B, 8*ch, 4, 4)
            spectral_norm(nn.ConvTranspose2d(z_dim, ch * 8, 4, 1, 0)),
            nn.BatchNorm2d(ch * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            spectral_norm(nn.ConvTranspose2d(ch * 8, ch * 4, 4, 2, 1)),
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(True),
            # NOTE(review): the original comment said "32x32", but for a
            # (B, z_dim, 1, 1) input the feature map here is 8x8.
            Self_Attn(ch * 4),
            # 8x8 -> 16x16
            spectral_norm(nn.ConvTranspose2d(ch * 4, ch * 2, 4, 2, 1)),
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            spectral_norm(nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1)),
            nn.BatchNorm2d(ch),
            nn.ReLU(True),
            # 32x32 -> 64x64, squashed to [-1, 1]
            spectral_norm(nn.ConvTranspose2d(ch, img_channels, 4, 2, 1)),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent ``z`` of shape (B, z_dim, 1, 1)."""
        return self.net(z)
class Discriminator(nn.Module):
    """SAGAN discriminator: scores 64x64 images with one logit each.

    Spatial sizes along the downsampling path:
    64 -> 32 -> 16 -> 8 -> 4 -> 1.
    """

    def __init__(self, img_channels=3, base_channels=64):
        super().__init__()
        ch = base_channels
        self.net = nn.Sequential(
            # 64x64 -> 32x32
            spectral_norm(nn.Conv2d(img_channels, ch, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 32x32 -> 16x16
            spectral_norm(nn.Conv2d(ch, ch * 2, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # Self-attention on the 16x16 map (the original comment said
            # 32x32, which does not match a 64x64 input).
            Self_Attn(ch * 2),
            # 16x16 -> 8x8
            spectral_norm(nn.Conv2d(ch * 2, ch * 4, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # BUGFIX: this stride-2 stage (8x8 -> 4x4) was missing. Without
            # it the final 4x4 valid conv saw an 8x8 map and produced a 5x5
            # patch map, so forward() returned 25*B values instead of B
            # logits for the 64x64 images the Generator emits.
            spectral_norm(nn.Conv2d(ch * 4, ch * 8, 4, 2, 1)),
            nn.LeakyReLU(0.1, True),
            # 4x4 -> 1x1: one raw (unnormalised) score per image.
            spectral_norm(nn.Conv2d(ch * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        """Return a flat tensor of one unnormalised score per image."""
        return self.net(x).view(-1)
# -------------------------
# High-Level Wrapper
# -------------------------
class SAGANModel(nn.Module):
    """Convenience wrapper bundling the SAGAN generator and discriminator."""

    def __init__(self, z_dim=128, img_channels=3, base_channels=64):
        super().__init__()
        self.gen = Generator(z_dim, img_channels, base_channels)
        self.dis = Discriminator(img_channels, base_channels)

    def forward(self, z):
        # Inference runs the generator only; training code is expected to
        # reach the discriminator directly via ``self.dis``.
        return self.gen(z)