StyleGAN / layers.py
masterofaudio2077's picture
Upload 11 files
18dcb10 verified
Raw
History Blame Contribute Delete
12 kB
from imports import * # gets all the libs
from configuration import * # gets all the constants
class EqualizedConv2D(keras.layers.Layer):
    """2-D convolution whose kernel is rescaled at call time (equalized LR).

    The stored kernel is multiplied by the constant ``sqrt(gain / fan_in)`` on
    every forward pass, where ``fan_in = kernel_size * kernel_size *
    in_channels``. The convolution uses stride 1 and ``'SAME'`` padding on
    channels-last inputs, and a per-filter bias is added afterwards.

    Args:
        filters: Number of output channels.
        kernel_size: Side length of the square kernel.
        gain: Numerator of the runtime scaling constant.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, gain=2.0, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.gain = gain

    def build(self, input_shape):
        """Create the kernel and bias and derive the runtime scale ``self.c``."""
        in_channels = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(self.kernel_size, self.kernel_size,
                   in_channels, self.filters),
            # The runtime scale below assumes unit-variance stored weights.
            # The 'random_normal' string alias uses Keras' default
            # stddev of 0.05, which would shrink the effective weights
            # 20x; spell the initializer out, as EqualizedDense does.
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer='zeros',
            trainable=True,
        )
        fan_in = self.kernel_size * self.kernel_size * in_channels
        # Plain Python float so the constant is not tracked as a weight.
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        """Convolve ``x`` with the rescaled kernel and add the bias."""
        x = keras.ops.conv(
            x, self.kernel * self.c,
            strides=1, padding='SAME', data_format='channels_last'
        )
        return x + self.bias
class PixelNorm(keras.layers.Layer):
    """Scale every feature vector (last axis) to unit root-mean-square.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + epsilon)``.

    Args:
        epsilon: Constant added under the square root.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def call(self, x):
        """Return ``x`` divided by its RMS taken over the last axis."""
        mean_of_squares = keras.ops.mean(
            keras.ops.square(x), axis=-1, keepdims=True
        )
        rms = keras.ops.sqrt(mean_of_squares + self.epsilon)
        return x / rms
class EqualizedDense(keras.layers.Layer):
    """Fully connected layer with runtime weight scaling and an LR multiplier.

    The kernel is initialised from ``N(0, (1 / lr_multiplier) ** 2)`` and
    multiplied by ``sqrt(gain / fan_in)`` at call time; the biased result is
    then multiplied by ``lr_multiplier``.

    Args:
        units: Output dimensionality.
        gain: Numerator of the runtime scaling constant.
        lr_multiplier: Factor applied to the layer output (and, inversely,
            to the kernel's initial standard deviation).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, gain=2.0, lr_multiplier=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.lr_multiplier = lr_multiplier

    def build(self, input_shape):
        """Create kernel/bias and compute the scale from the input width."""
        fan_in = input_shape[-1]
        kernel_init = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.lr_multiplier
        )
        self.kernel = self.add_weight(
            shape=(fan_in, self.units),
            initializer=kernel_init,
            trainable=True,
        )
        self.bias = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
        )
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        """Apply the scaled affine map, then the learning-rate multiplier."""
        scaled_kernel = self.kernel * self.c
        affine = keras.ops.matmul(x, scaled_kernel) + self.bias
        return affine * self.lr_multiplier
class MinibatchStd(keras.layers.Layer):
    """Append one channel holding the per-group average standard deviation.

    The batch is split into groups, the standard deviation across each group
    is averaged over height, width and channels, and the resulting scalar is
    broadcast as one extra feature map, so a ``(batch, h, w, c)`` input
    becomes ``(batch, h, w, c + 1)``.

    Args:
        group_size: Preferred number of samples per group.
        epsilon: Constant added under the square root.
    """

    def __init__(self, group_size=4, epsilon=1e-8):
        super().__init__()
        self.group_size = group_size
        self.epsilon = epsilon

    def call(self, x):
        """Return ``x`` with the minibatch-stddev feature map concatenated.

        The static shape of ``x`` is read, so the batch dimension must be a
        concrete integer at call time.
        """
        # Static (Python int) dimensions, not traced values.
        batch = x.shape[0]
        h = x.shape[1]
        w = x.shape[2]
        c = x.shape[3]

        group_size = max(1, min(self.group_size, batch))
        # The reshape below needs batch to be an exact multiple of the group
        # size; fall back to the largest divisor so a trailing partial batch
        # (e.g. 6 samples with group_size=4) no longer raises.
        while batch % group_size:
            group_size -= 1

        y = keras.ops.reshape(x, [group_size, -1, h, w, c])
        mean = keras.ops.mean(y, axis=0, keepdims=True)
        var = keras.ops.mean(keras.ops.square(y - mean), axis=0)
        std = keras.ops.sqrt(var + self.epsilon)
        # One scalar per group, then repeated to cover every sample/pixel.
        avg_std = keras.ops.mean(std, axis=[1, 2, 3], keepdims=True)
        avg_std = keras.ops.tile(avg_std, [group_size, h, w, 1])
        return keras.ops.concatenate([x, avg_std], axis=-1)
# ── gain=1.0 for linear projections ──────────────────────
def toRGB():
    """Build the 1x1 equalized convolution projecting features to 3 channels.

    The layer uses ``gain=1.0`` and is not followed by an activation.
    """
    rgb_projection = EqualizedConv2D(filters=3, kernel_size=1, gain=1.0)
    return rgb_projection
def fromRGB(filters):
    """Build the image-to-features stem: 1x1 equalized conv + LeakyReLU(0.2).

    Args:
        filters: Number of output channels of the 1x1 convolution.
    """
    stem_layers = [
        EqualizedConv2D(filters, kernel_size=1, gain=2.0),
        keras.layers.LeakyReLU(0.2),
    ]
    return keras.Sequential(stem_layers)
class GenBlock(keras.layers.Layer):
    """Two (conv -> LeakyReLU -> PixelNorm) stages with a shared width.

    Args:
        filters: Output channels of both convolutions.
    """

    def __init__(self, filters):
        super().__init__()
        self.conv1 = EqualizedConv2D(filters, kernel_size=4)
        self.conv2 = EqualizedConv2D(filters, kernel_size=4)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pnorm = PixelNorm()

    def call(self, x):
        """Run both convolution stages in order and return the result."""
        for conv in (self.conv1, self.conv2):
            x = self.pnorm(self.act(conv(x)))
        return x
class DiscBlock(keras.layers.Layer):
    """Two activated equalized convolutions followed by 2x average pooling.

    Args:
        in_filters: Output channels of the first convolution.
        out_filters: Output channels of the second convolution.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, in_filters, out_filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = EqualizedConv2D(in_filters, 4)
        self.conv2 = EqualizedConv2D(out_filters, 4)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pool = keras.layers.AveragePooling2D(pool_size=2)

    def call(self, x):
        """Apply conv1 and conv2 (each with LeakyReLU), then pool."""
        for conv in (self.conv1, self.conv2):
            x = self.act(conv(x))
        return self.pool(x)
def build_mapping_network():
    """Build the mapping network as a ``keras.Sequential``.

    The model is eight repetitions of ``EqualizedDense(512, gain=2.0,
    lr_multiplier=0.01)`` followed by ``LeakyReLU(0.2)``, named
    ``"mapping_network"``.
    """
    stack = []
    for _ in range(8):
        # lr_multiplier=0.01 scales this layer's output by 0.01.
        stack.append(EqualizedDense(512, gain=2.0, lr_multiplier=0.01))
        stack.append(keras.layers.LeakyReLU(0.2))
    return keras.Sequential(stack, name="mapping_network")
class AdaIN(keras.layers.Layer):
    """Adaptive instance normalisation driven by a style vector ``w``.

    Each channel of ``x`` is normalised over its spatial axes (1 and 2), then
    scaled and shifted by values produced from ``w`` through a single
    equalized dense layer.

    Args:
        channels: Number of channels in the feature map being modulated.
        w_dim: Width of the style vector (stored only; the dense layer infers
            its input width when built).
        epsilon: Constant added to the variance before the square root.
    """

    def __init__(self, channels, w_dim=512, epsilon=1e-8):
        super().__init__()
        self.channels = channels
        self.w_dim = w_dim
        self.epsilon = epsilon
        # Single affine layer -> 2 * channels, later split into ys and yb.
        self.style_transform = EqualizedDense(2 * channels, gain=1.0)

    def call(self, x, w):
        """Return the instance-normalised ``x`` modulated by the style ``w``."""
        # Per-sample, per-channel statistics over H and W.
        mean = keras.ops.mean(x, axis=[1, 2], keepdims=True)
        var = keras.ops.var(x, axis=[1, 2], keepdims=True)
        # Epsilon goes inside the square root: sqrt has an unbounded
        # derivative at 0, so `std(x) + eps` yields non-finite gradients for
        # a spatially constant channel. This also matches how PixelNorm and
        # MinibatchStd in this file apply their epsilon.
        normalized_x = (x - mean) / keras.ops.sqrt(var + self.epsilon)

        # Single affine -> split into scale (ys) and shift (yb).
        style = self.style_transform(w)
        ys, yb = keras.ops.split(style, 2, axis=-1)
        ys = keras.ops.reshape(ys, [-1, 1, 1, self.channels])
        yb = keras.ops.reshape(yb, [-1, 1, 1, self.channels])
        # NOTE(review): the scale is used as-is (no `+ 1` offset, and the
        # dense bias starts at zero), so the initial scale is centred on 0
        # rather than 1 -- confirm this is intended.
        return ys * normalized_x + yb
# ─── Noise Injection ──────────────────────────────────────
class NoiseInjection(keras.layers.Layer):
    """Add single-channel Gaussian noise scaled by a learned per-channel weight.

    Args:
        channels: Number of channels of the feature map receiving the noise.
    """

    def __init__(self, channels):
        super().__init__()
        # Zero-initialised, so the layer starts out as the identity.
        self.B = self.add_weight(
            name="noise_scale",
            shape=(1, 1, 1, channels),
            initializer="zeros",
            trainable=True,
        )

    def call(self, x, rng_key):
        """Return ``x + B * noise`` with noise drawn from ``rng_key``.

        The noise has one channel and is broadcast across the channels of
        ``x`` by the multiplication with ``B``.
        """
        dims = keras.ops.shape(x)
        noise_shape = (dims[0], dims[1], dims[2], 1)
        # Pure JAX RNG: the caller supplies (and splits) the key.
        noise = jax.random.normal(rng_key, shape=noise_shape)
        return x + self.B * noise
# ─── StyleGAN Block ───────────────────────────────────────
class StyleBlock(keras.layers.Layer):
    """One synthesis step: conv -> noise injection -> LeakyReLU -> AdaIN.

    Args:
        channels: Output channels of the convolution (and of the noise/AdaIN
            parameters).
        w_dim: Style-vector width passed on to ``AdaIN``.
    """

    def __init__(self, channels, w_dim=512):
        super().__init__()
        self.conv = EqualizedConv2D(channels, 4)
        self.noise = NoiseInjection(channels)
        self.adain = AdaIN(channels, w_dim)
        self.act = keras.layers.LeakyReLU(0.2)

    def call(self, x, w, rng_key):
        """Apply the block to ``x`` using style ``w`` and noise key ``rng_key``."""
        features = self.conv(x)                    # equalized conv, kernel_size=4
        features = self.noise(features, rng_key)   # learned-scale noise
        features = self.act(features)
        return self.adain(features, w)             # style modulation from w
class StyleGAN_Generator(keras.Model):
    """Progressively grown style-based generator.

    Holds a mapping network, a learned ``(1, 4, 4, 512)`` constant input, and
    for every entry of ``RESOLUTIONS`` two ``StyleBlock`` layers plus a
    ``toRGB`` projection. ``current_resolution`` (initially 4) selects the
    resolution at which ``call`` stops and emits an image.
    """

    def __init__(self):
        super().__init__()
        self.mapping_network = build_mapping_network()
        self.const = self.add_weight(
            name="const",
            shape=(1, 4, 4, 512),
            initializer="ones",
            trainable=True,
        )
        # Attach layers with setattr so Keras tracks them under stable names.
        for res in RESOLUTIONS:
            channels = RES_TO_FILTERS[res]
            for idx in range(2):
                setattr(self, f"block_{res}_{idx}", StyleBlock(channels))
            setattr(self, f"to_rgb_{res}", toRGB())
        self.upsample = keras.layers.UpSampling2D(size=2, interpolation="bilinear")
        self.current_resolution = 4

    def _run_stage(self, res, x, w, rng_key):
        """Apply both style blocks of ``res``; return ``(x, advanced_key)``.

        A fresh subkey is split off before each block.
        """
        for idx in range(2):
            rng_key, subkey = jax.random.split(rng_key)
            x = getattr(self, f"block_{res}_{idx}")(x, w, subkey)
        return x, rng_key

    def call(self, z, alpha, rng_key):
        """Generate an RGB image at ``self.current_resolution``.

        Args:
            z: Latent batch fed to the mapping network.
            alpha: Fade-in weight; the output above 4x4 is
                ``(1 - alpha) * old_rgb + alpha * new_rgb``.
            rng_key: JAX PRNG key used for the noise injections.
        """
        w = self.mapping_network(z)
        n_samples = keras.ops.shape(z)[0]
        x = keras.ops.tile(self.const, [n_samples, 1, 1, 1])
        x, rng_key = self._run_stage(4, x, w, rng_key)

        target = self.current_resolution
        if target == 4:
            return self.to_rgb_4(x)

        for res in RESOLUTIONS[1:]:
            coarse = x
            x = self.upsample(coarse)
            x, rng_key = self._run_stage(res, x, w, rng_key)
            if res != target:
                continue
            # Blend the upsampled previous-resolution image with the new one.
            old_rgb = getattr(self, f"to_rgb_{res // 2}")(self.upsample(coarse))
            new_rgb = getattr(self, f"to_rgb_{res}")(x)
            return (1 - alpha) * old_rgb + alpha * new_rgb

        # Reached only when `target` matched none of RESOLUTIONS[1:].
        return getattr(self, f"to_rgb_{target}")(x)
class ProGAN_Discriminator(keras.Model):
    """Progressively grown discriminator mirroring the generator's stages.

    Holds one ``DiscBlock`` per resolution above the first, one ``fromRGB``
    stem per resolution, and a shared head (minibatch stddev -> 3x3 conv ->
    flatten -> dense(512) -> dense(1)). ``current_resolution`` (initially 4)
    selects which stem receives the input image.
    """

    def __init__(self):
        super().__init__()
        # Attach layers with setattr so Keras tracks them under stable names.
        for res in RESOLUTIONS[1:]:
            block = DiscBlock(RES_TO_FILTERS[res], RES_TO_FILTERS[res // 2])
            setattr(self, f"block_{res}", block)
        for res in RESOLUTIONS:
            setattr(self, f"from_rgb_{res}", fromRGB(RES_TO_FILTERS[res]))
        self.minibatch_std = MinibatchStd()
        self.final_conv = EqualizedConv2D(512, kernel_size=3)
        self.final_dense_1 = EqualizedDense(512)  # intermediate dense layer
        self.final_dense_2 = EqualizedDense(1)
        self.flatten = keras.layers.Flatten()
        self.act = keras.layers.LeakyReLU(0.2)
        self.downsample = keras.layers.AveragePooling2D(pool_size=2)
        self.current_resolution = 4

    def _head(self, x):
        """Map final-stage features to one unactivated score per sample."""
        x = self.minibatch_std(x)
        x = self.act(self.final_conv(x))
        x = self.flatten(x)
        x = self.act(self.final_dense_1(x))
        return self.final_dense_2(x)

    def call(self, img, alpha):
        """Score ``img`` at ``self.current_resolution``.

        Args:
            img: Image batch at the current resolution.
            alpha: Fade-in weight blending the downsampled-image path
                (``1 - alpha``) with the new-block path (``alpha``). Unused
                when the current resolution is 4.
        """
        cur_res = self.current_resolution
        if cur_res == 4:
            return self._head(self.from_rgb_4(img))

        prev_res = cur_res // 2
        # New path: stem at the current resolution, then its block.
        x_new = getattr(self, f"from_rgb_{cur_res}")(img)
        x_new = getattr(self, f"block_{cur_res}")(x_new)
        # Old path: shrink the image first, then the previous stem.
        x_old = getattr(self, f"from_rgb_{prev_res}")(self.downsample(img))
        x = (1 - alpha) * x_old + alpha * x_new

        # Remaining blocks, from just below cur_res down to the smallest.
        for res in reversed(RESOLUTIONS[1:]):
            if res < cur_res:
                x = getattr(self, f"block_{res}")(x)
        return self._head(x)