from imports import *  # gets all the libs (presumably provides keras, jax and np used below)
from configuration import *  # gets all the constants (RESOLUTIONS, RES_TO_FILTERS are used below)


class EqualizedConv2D(keras.layers.Layer):
    """Square 2-D convolution with equalized learning rate (ProGAN-style).

    The kernel is stored as unit-variance N(0, 1) samples and multiplied at
    call time by ``c = sqrt(gain / fan_in)``, so every layer gets the same
    effective step size regardless of its fan-in.

    Args:
        filters: Number of output channels.
        kernel_size: Side length of the square kernel. Stride is 1 and
            padding is 'same', so H and W are preserved.
        gain: 2.0 for layers followed by a (leaky) ReLU, 1.0 for linear
            projections.
    """

    def __init__(self, filters, kernel_size, gain=2.0, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.gain = gain

    def build(self, input_shape):
        in_channels = input_shape[-1]
        # Equalized LR requires unit-variance storage. The string
        # 'random_normal' initializer defaults to stddev=0.05, which on top
        # of the runtime scale ``c`` shrank every conv by ~20x (and did not
        # match EqualizedDense, which stores N(0, 1 / lr_multiplier)).
        self.kernel = self.add_weight(
            shape=(self.kernel_size, self.kernel_size, in_channels, self.filters),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer="zeros",
            trainable=True,
        )
        fan_in = self.kernel_size * self.kernel_size * in_channels
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        # Scale at run time (not at init) so gradients act on the
        # unit-variance parameters. Input is channels-last (NHWC).
        y = keras.ops.conv(
            x,
            self.kernel * self.c,
            strides=1,
            padding="same",
            data_format="channels_last",
        )
        return y + self.bias


class PixelNorm(keras.layers.Layer):
    """Normalize every feature vector (last axis) to unit RMS.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + epsilon)``.
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, x):
        mean_sq = keras.ops.mean(keras.ops.square(x), axis=-1, keepdims=True)
        return x / keras.ops.sqrt(mean_sq + self.epsilon)


class EqualizedDense(keras.layers.Layer):
    """Dense layer with equalized learning rate and a learning-rate multiplier.

    The kernel is stored as N(0, 1 / lr_multiplier) samples and the output
    is ``(x @ (kernel * c) + bias) * lr_multiplier`` with
    ``c = sqrt(gain / fan_in)``. Effective initial weights are thus
    N(0, c ** 2), while updates to the stored parameters are scaled by
    ``lr_multiplier`` (the mapping network below uses 0.01).

    Args:
        units: Output width.
        gain: 2.0 before a (leaky) ReLU, 1.0 for linear projections.
        lr_multiplier: Relative learning rate of this layer.
    """

    def __init__(self, units, gain=2.0, lr_multiplier=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.lr_multiplier = lr_multiplier

    def build(self, input_shape):
        fan_in = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(fan_in, self.units),
            initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=1.0 / self.lr_multiplier
            ),
            trainable=True,
        )
        self.bias = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        return (keras.ops.matmul(x, self.kernel * self.c) + self.bias) * self.lr_multiplier


class MinibatchStd(keras.layers.Layer):
    """Append one channel holding the average feature std within sample groups.

    The batch is reshaped into ``group_size`` slices. The std of every
    feature across those slices is averaged over H, W and C and tiled back
    as one extra channel. This gives the discriminator a batch-diversity
    signal.

    Expects NHWC input with a static (Python int) batch size, because the
    group size and tile multiples are computed from ``x.shape``.
    """

    def __init__(self, group_size=4, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.group_size = group_size
        self.epsilon = epsilon

    def _effective_group_size(self, batch):
        """Return the largest divisor of ``batch`` that is <= ``self.group_size``."""
        # The reshape below needs batch % group_size == 0; the plain
        # min(group_size, batch) crashed on e.g. a ragged batch of 6.
        group_size = max(1, min(self.group_size, batch))
        while batch % group_size:
            group_size -= 1
        return group_size

    def call(self, x):
        batch, h, w, c = x.shape  # static shapes, not traced values
        group_size = self._effective_group_size(batch)
        # Row-major reshape to (G, M, H, W, C): sample i sits at
        # (i // M, i % M), so statistics over axis 0 pool samples that
        # share the same index modulo M.
        y = keras.ops.reshape(x, [group_size, -1, h, w, c])
        mean = keras.ops.mean(y, axis=0, keepdims=True)
        var = keras.ops.mean(keras.ops.square(y - mean), axis=0)
        std = keras.ops.sqrt(var + self.epsilon)
        avg_std = keras.ops.mean(std, axis=[1, 2, 3], keepdims=True)  # (M, 1, 1, 1)
        # Tiling by G along axis 0 puts avg_std[m] back at every index g*M + m.
        avg_std = keras.ops.tile(avg_std, [group_size, h, w, 1])
        return keras.ops.concatenate([x, avg_std], axis=-1)


# ── gain=1.0 for linear projections ──────────────────────
def toRGB():
    """Return a 1x1 linear projection from features to 3 RGB channels."""
    # No activation follows, hence gain=1.0.
    return EqualizedConv2D(filters=3, kernel_size=1, gain=1.0)


def fromRGB(filters):
    """Return a 1x1 projection from RGB to ``filters`` channels + LeakyReLU(0.2)."""
    return keras.Sequential([
        EqualizedConv2D(filters, kernel_size=1, gain=2.0),
        keras.layers.LeakyReLU(0.2),
    ])


class GenBlock(keras.layers.Layer):
    """ProGAN generator block: two (conv -> LeakyReLU(0.2) -> PixelNorm) stages.

    Args:
        filters: Output channels of both convolutions.
        kernel_size: Conv kernel size. Defaults to 3: an even kernel with
            'same' padding pads asymmetrically and shifts features by half
            a pixel per layer.
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = EqualizedConv2D(filters, kernel_size=kernel_size)
        self.conv2 = EqualizedConv2D(filters, kernel_size=kernel_size)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pnorm = PixelNorm()

    def call(self, x):
        x = self.pnorm(self.act(self.conv1(x)))
        return self.pnorm(self.act(self.conv2(x)))


class DiscBlock(keras.layers.Layer):
    """ProGAN discriminator block: conv -> conv -> 2x2 average pool.

    Args:
        in_filters: Output channels of the first conv.
        out_filters: Output channels of the second conv (block output).
        kernel_size: Conv kernel size. Defaults to 3 (see GenBlock for why
            even kernels are avoided).
    """

    def __init__(self, in_filters, out_filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = EqualizedConv2D(in_filters, kernel_size)
        self.conv2 = EqualizedConv2D(out_filters, kernel_size)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pool = keras.layers.AveragePooling2D(pool_size=2)

    def call(self, x):
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return self.pool(x)


def build_mapping_network(num_layers=8, units=512, lr_multiplier=0.01,
                          normalize_latents=True):
    """Build the StyleGAN mapping network z -> w as a keras.Sequential.

    Args:
        num_layers: Number of EqualizedDense + LeakyReLU(0.2) pairs.
        units: Width of every dense layer (and therefore of w).
        lr_multiplier: Learning-rate multiplier of every dense layer; 0.01
            makes the mapping network train 100x slower.
        normalize_latents: Prepend a PixelNorm so z is normalized to unit
            RMS before mapping, as StyleGAN does.
    """
    stack = [PixelNorm()] if normalize_latents else []
    for _ in range(num_layers):
        stack.append(EqualizedDense(units, gain=2.0, lr_multiplier=lr_multiplier))
        stack.append(keras.layers.LeakyReLU(0.2))
    return keras.Sequential(stack, name="mapping_network")


class AdaIN(keras.layers.Layer):
    """Adaptive instance normalization driven by a style vector w.

    Each channel is normalized over H and W. It is then modulated as
    ``(1 + ys) * x_norm + yb``, where ``ys`` and ``yb`` come from a single
    linear EqualizedDense of w. The ``1 +`` offset (as in StyleGAN's
    style_mod) makes the layer start close to plain instance norm, instead
    of scaling by a random, zero-centred factor.

    Args:
        channels: Number of feature channels of ``x``.
        w_dim: Stored for reference; the dense layer infers its fan-in.
        epsilon: Added to the variance before the square root.
    """

    def __init__(self, channels, w_dim=512, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.w_dim = w_dim
        self.epsilon = epsilon
        # single affine layer -> 2 * channels, then split into ys and yb
        self.style_transform = EqualizedDense(2 * channels, gain=1.0)

    def call(self, x, w):
        # Per-sample, per-channel statistics over H, W. Adding epsilon
        # inside the sqrt keeps the gradient finite for zero variance.
        mean = keras.ops.mean(x, axis=[1, 2], keepdims=True)
        var = keras.ops.var(x, axis=[1, 2], keepdims=True)
        normalized_x = (x - mean) / keras.ops.sqrt(var + self.epsilon)

        style = self.style_transform(w)  # (batch, 2 * channels)
        ys, yb = keras.ops.split(style, 2, axis=-1)
        ys = keras.ops.reshape(ys, [-1, 1, 1, self.channels])
        yb = keras.ops.reshape(yb, [-1, 1, 1, self.channels])
        return (ys + 1.0) * normalized_x + yb


# ─── Noise Injection ──────────────────────────────────────
class NoiseInjection(keras.layers.Layer):
    """Add per-pixel Gaussian noise scaled by a learned per-channel factor.

    One N(0, 1) map of shape (batch, H, W, 1) is drawn from ``rng_key`` and
    broadcast across channels, scaled by ``B`` (initialized to zeros, so the
    layer starts as the identity). JAX-only: it calls ``jax.random.normal``
    directly.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.B = self.add_weight(
            shape=(1, 1, 1, channels),
            initializer="zeros",
            trainable=True,
            name="noise_scale",
        )

    def call(self, x, rng_key):
        batch, h, w = keras.ops.shape(x)[:3]
        noise = jax.random.normal(rng_key, shape=(batch, h, w, 1))  # pure JAX rng
        return x + self.B * noise


# ─── StyleGAN Block ───────────────────────────────────────
class StyleBlock(keras.layers.Layer):
    """StyleGAN synthesis layer: conv -> noise -> LeakyReLU(0.2) -> AdaIN(w).

    Args:
        channels: Output channels.
        w_dim: Forwarded to AdaIN.
        kernel_size: Conv kernel size; 3x3 by default, because an even
            kernel with 'same' padding shifts features by half a pixel.
    """

    def __init__(self, channels, w_dim=512, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.conv = EqualizedConv2D(channels, kernel_size)
        self.noise = NoiseInjection(channels)
        self.adain = AdaIN(channels, w_dim)
        self.act = keras.layers.LeakyReLU(0.2)

    def call(self, x, w, rng_key):
        x = self.conv(x)
        x = self.noise(x, rng_key)
        x = self.act(x)
        return self.adain(x, w)  # AdaIN with style from w


class StyleGAN_Generator(keras.Model):
    """Progressive StyleGAN generator (JAX backend).

    Starts from a learned (1, 4, 4, 512) constant tiled over the batch. It
    applies two StyleBlocks per resolution in RESOLUTIONS, with a 2x
    bilinear upsample between resolutions, and projects to RGB at
    ``current_resolution``. At the newest resolution, the output is blended
    with the previous resolution's upsampled RGB output using ``alpha``.

    NOTE(review): assumes RESOLUTIONS starts at 4 and doubles each step
    (the code looks up ``res // 2``) — confirm against configuration.
    """

    def __init__(self):
        super().__init__()
        self.mapping_network = build_mapping_network()
        self.const = self.add_weight(
            shape=(1, 4, 4, 512),
            initializer="ones",
            trainable=True,
            name="const",
        )
        # ── register via setattr so Keras tracks them ─────
        for res in RESOLUTIONS:
            setattr(self, f"block_{res}_0", StyleBlock(RES_TO_FILTERS[res]))
            setattr(self, f"block_{res}_1", StyleBlock(RES_TO_FILTERS[res]))
            setattr(self, f"to_rgb_{res}", toRGB())
        self.upsample = keras.layers.UpSampling2D(size=2, interpolation="bilinear")
        self.current_resolution = 4

    def _synthesize(self, res, x, w, rng_key):
        """Run both StyleBlocks of resolution ``res``; return ``(x, rng_key)``."""
        for i in range(2):
            rng_key, subkey = jax.random.split(rng_key)
            x = getattr(self, f"block_{res}_{i}")(x, w, subkey)
        return x, rng_key

    def call(self, z, alpha, rng_key):
        """Generate images at ``current_resolution``.

        Args:
            z: Latent batch fed to the mapping network.
            alpha: Fade-in weight of the newest resolution (1.0 = fully new).
            rng_key: JAX PRNG key, split once per StyleBlock for its noise.

        Returns:
            RGB images from the matching ``to_rgb`` projection.
        """
        w = self.mapping_network(z)
        batch = keras.ops.shape(z)[0]
        x = keras.ops.tile(self.const, [batch, 1, 1, 1])
        x, rng_key = self._synthesize(4, x, w, rng_key)
        if self.current_resolution == 4:
            return getattr(self, "to_rgb_4")(x)

        for res in RESOLUTIONS[1:]:
            x_prev = x
            x = self.upsample(x)
            x, rng_key = self._synthesize(res, x, w, rng_key)
            if res == self.current_resolution:
                # to_rgb is a per-pixel affine map and bilinear upsampling
                # forms convex combinations, so projecting first is
                # equivalent and upsamples 3 channels instead of the full
                # feature map.
                old_rgb = self.upsample(getattr(self, f"to_rgb_{res // 2}")(x_prev))
                new_rgb = getattr(self, f"to_rgb_{res}")(x)
                return (1 - alpha) * old_rgb + alpha * new_rgb

        return getattr(self, f"to_rgb_{self.current_resolution}")(x)


class ProGAN_Discriminator(keras.Model):
    """Progressive ProGAN discriminator.

    ``from_rgb_{res}`` maps an RGB image to RES_TO_FILTERS[res] channels.
    ``block_{res}`` maps that to RES_TO_FILTERS[res // 2] channels at half
    the spatial size. At ``current_resolution`` the new path is blended with
    a downsampled-image path using ``alpha``. The remaining blocks then run
    down to 4x4, followed by a shared head (minibatch std -> conv -> two
    dense layers) that produces one linear score per image.
    """

    def __init__(self):
        super().__init__()
        # ── register via setattr so Keras tracks them ─────
        for res in RESOLUTIONS[1:]:
            setattr(self, f"block_{res}",
                    DiscBlock(RES_TO_FILTERS[res], RES_TO_FILTERS[res // 2]))
        for res in RESOLUTIONS:
            setattr(self, f"from_rgb_{res}", fromRGB(RES_TO_FILTERS[res]))
        self.minibatch_std = MinibatchStd()
        self.final_conv = EqualizedConv2D(512, kernel_size=3)
        self.final_dense_1 = EqualizedDense(512)
        # The score is linear (no activation follows), so gain=1.0 like the
        # other linear projections in this file.
        self.final_dense_2 = EqualizedDense(1, gain=1.0)
        self.flatten = keras.layers.Flatten()
        self.act = keras.layers.LeakyReLU(0.2)
        self.downsample = keras.layers.AveragePooling2D(pool_size=2)
        self.current_resolution = 4

    def _head(self, x):
        """Shared 4x4 head: minibatch std -> conv -> flatten -> dense -> dense(1)."""
        x = self.minibatch_std(x)
        x = self.act(self.final_conv(x))
        x = self.flatten(x)
        x = self.act(self.final_dense_1(x))  # intermediate dense
        return self.final_dense_2(x)

    def call(self, img, alpha):
        """Score ``img`` (at ``current_resolution``); ``alpha`` weights the new path."""
        cur_res = self.current_resolution
        if cur_res == 4:
            return self._head(getattr(self, "from_rgb_4")(img))

        prev_res = cur_res // 2
        x_new = getattr(self, f"from_rgb_{cur_res}")(img)
        x_new = getattr(self, f"block_{cur_res}")(x_new)
        x_old = getattr(self, f"from_rgb_{prev_res}")(self.downsample(img))
        x = (1 - alpha) * x_old + alpha * x_new

        # Remaining blocks, largest resolution first, down to 4x4.
        for res in reversed(RESOLUTIONS[1:]):
            if res < cur_res:
                x = getattr(self, f"block_{res}")(x)
        return self._head(x)