leharris3's picture
Minimal HF Space deployment with gradio 5.x fix
0917e8d
import torch
import torch.nn as nn
import torch.nn.functional as F
class Generator(nn.Module):
    """Convolutional generator built from per-layer hyper-parameter lists.

    Layer ``i`` maps ``gf[i]`` channels to ``gf[i + 1]`` channels with kernel
    size ``gk[i]``, stride ``gs[i]`` and padding ``gp[i]``:

    * every layer before the last two is a bias-free ``ConvTranspose2d``;
    * the last two layers are bias-free ``Conv2d`` with reflect padding.

    A ``BatchNorm2d`` is created for every layer. The one attached to the final
    layer is never used in ``forward``, but it is kept because removing it would
    change the ``state_dict`` keys.

    :param config: object exposing ``image_type``; ``"n-phase"`` selects a
        channel-wise softmax output, any other value an element-wise sigmoid.
    :param gk: per-layer kernel sizes.
    :param gs: per-layer strides.
    :param gf: per-layer channel counts (indexed up to ``len(gk)``, so it needs
        ``len(gk) + 1`` entries).
    :param gp: per-layer paddings.
    """

    def __init__(self, config, gk, gs, gf, gp):
        super().__init__()
        self.config = config
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.no_layers = len(gk)
        for lay, (k, s, p) in enumerate(zip(gk, gs, gp)):
            if lay < self.no_layers - 2:
                conv = nn.ConvTranspose2d(gf[lay], gf[lay + 1], k, s, p, bias=False)
            else:
                conv = nn.Conv2d(
                    gf[lay],
                    gf[lay + 1],
                    k,
                    s,
                    p,
                    bias=False,
                    padding_mode="reflect",
                )
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm2d(gf[lay + 1]))

    def forward(self, x: torch.Tensor):
        """Generate an image from ``x``.

        :param x: 4-D input, presumably (batch, gf[0], H, W) -- TODO confirm
            against callers. Bilinear ``F.interpolate`` requires a 4-D tensor.
        :return: 4-D tensor (batch, channels, H_out, W_out). It is softmax over
            dim 1 when ``config.image_type == "n-phase"``, sigmoid otherwise.
        """
        # Index of the first hidden layer that is a plain Conv2d rather than a
        # ConvTranspose2d (mirrors the split made in __init__).
        first_plain_conv = self.no_layers - 2
        for idx, (conv, bn) in enumerate(zip(self.convs[:-1], self.bns[:-1])):
            x = conv(x)
            if idx >= first_plain_conv:
                # Plain Conv2d layers are followed by a bilinear resize to
                # (2H + 2, 2W + 2) before normalisation.
                x = F.interpolate(
                    x,
                    [x.shape[-2] * 2 + 2, x.shape[-1] * 2 + 2],
                    mode="bilinear",
                    align_corners=False,
                )
            x = bn(x)
            x = F.relu_(x)
        # The final conv has no batch norm / ReLU; it feeds the output activation.
        logits = self.convs[-1](x)
        if self.config.image_type == "n-phase":
            return torch.softmax(logits, dim=1)
        return torch.sigmoid(logits)
class Discriminator(nn.Module):
    """Stack of bias-free ``Conv2d`` layers with in-place ReLU in between.

    Layer ``i`` maps ``df[i]`` channels to ``df[i + 1]`` channels with kernel
    size ``dk[i]``, stride ``ds[i]`` and padding ``dp[i]``. No activation is
    applied after the last layer.

    :param dk: per-layer kernel sizes.
    :param ds: per-layer strides.
    :param dp: per-layer paddings.
    :param df: per-layer channel counts (needs one more entry than ``dk``).
    """

    def __init__(self, dk, ds, dp, df):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv2d(df[idx], df[idx + 1], kernel, stride, pad, bias=False)
            for idx, (kernel, stride, pad) in enumerate(zip(dk, ds, dp))
        )

    def forward(self, x):
        """Return the raw (un-activated) output of the final conv layer.

        The output shape depends on the configured kernels/strides; it is
        presumably (batch, 1, h, w) with h = w = 1 for the intended config --
        TODO confirm.
        """
        hidden_layers, final_layer = self.convs[:-1], self.convs[-1]
        for layer in hidden_layers:
            x = F.relu_(layer(x))
        return final_layer(x)
def make_nets(config, training=True):
    """Build the Discriminator and Generator from the config's network params.

    When ``training`` is truthy, ``config.save()`` is called first; otherwise
    ``config.load()`` is called (presumably restoring previously saved params
    from a file -- confirm in Config). The params are then read with
    ``config.get_net_params()``.

    :param config: a Config class object
    :type config: Config
    :param training: choose ``config.save()`` (True) or ``config.load()``
        (False) before reading the params, defaults to True
    :type training: bool, optional
    :return: Discriminator and Generator class objects, in that order
    :rtype: Discriminator, Generator
    """
    sync_params = config.save if training else config.load
    sync_params()

    dk, ds, df, dp, gk, gs, gf, gp = config.get_net_params()

    # Discriminator takes (kernels, strides, paddings, filters), so df/dp are
    # swapped relative to the get_net_params() ordering. It is built before the
    # Generator to keep the weight-initialisation RNG order unchanged.
    discriminator = Discriminator(dk, ds, dp, df)
    generator = Generator(config, gk, gs, gf, gp)
    return discriminator, generator