Spaces:
Sleeping
Sleeping
| from imports import * # gets all the libs | |
| from configuration import * # gets all the constants | |
class EqualizedConv2D(keras.layers.Layer):
    """2D convolution with an equalized learning rate (ProGAN / StyleGAN).

    The kernel is stored with unit-variance N(0, 1) initialisation and is
    multiplied at call time by the He constant ``c = sqrt(gain / fan_in)``.
    The effective weights therefore start at He scale, while every stored
    parameter has a comparable magnitude for the optimizer.

    Args:
        filters: Number of output channels.
        kernel_size: Height/width of the square kernel.
        gain: Gain inside the He constant. Use 2.0 before a (leaky) ReLU and
            1.0 for linear projections.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, gain=2.0, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.gain = gain

    def build(self, input_shape):
        in_channels = input_shape[-1]
        # Unit-variance init: the He factor is applied at runtime via self.c.
        # The "random_normal" string initializer defaults to stddev=0.05,
        # which would leave the effective weights ~20x below He scale.
        self.kernel = self.add_weight(
            shape=(self.kernel_size, self.kernel_size, in_channels, self.filters),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer="zeros",
            trainable=True,
        )
        fan_in = self.kernel_size * self.kernel_size * in_channels
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        """Convolve channels-last ``x`` (stride 1, 'same' padding) and add the bias."""
        x = keras.ops.conv(
            x,
            self.kernel * self.c,
            strides=1,
            padding="same",
            data_format="channels_last",
        )
        return x + self.bias
class PixelNorm(keras.layers.Layer):
    """Pixelwise feature normalisation along the last axis.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + epsilon)``, so every feature
    vector ends up with (approximately) unit root-mean-square.

    Args:
        epsilon: Small constant added under the square root for stability.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def call(self, x):
        ops = keras.ops
        mean_square = ops.mean(ops.square(x), axis=-1, keepdims=True)
        return x / ops.sqrt(mean_square + self.epsilon)
class EqualizedDense(keras.layers.Layer):
    """Fully connected layer with an equalized learning rate.

    The effective kernel is ``kernel * sqrt(gain / fan_in)``. With
    ``lr_multiplier != 1`` the kernel is initialised with stddev
    ``1 / lr_multiplier``, and the whole output (bias included) is multiplied
    by ``lr_multiplier``. The initial effective weights are unchanged, but
    updates to the stored parameters act ``lr_multiplier`` times more weakly.
    This is used by the mapping network to train it more slowly.

    Args:
        units: Output dimensionality.
        gain: Gain inside the He constant (2.0 before leaky ReLU, 1.0 for
            linear projections).
        lr_multiplier: Learning-rate multiplier for this layer's parameters.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, gain=2.0, lr_multiplier=1.0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.lr_multiplier = lr_multiplier

    def build(self, input_shape):
        fan_in = input_shape[-1]
        init = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.lr_multiplier
        )
        self.kernel = self.add_weight(
            shape=(fan_in, self.units), initializer=init, trainable=True
        )
        self.bias = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True
        )
        self.c = float(np.sqrt(self.gain / fan_in))

    def call(self, x):
        scaled_kernel = self.kernel * self.c
        pre_scale = keras.ops.matmul(x, scaled_kernel) + self.bias
        return pre_scale * self.lr_multiplier
class MinibatchStd(keras.layers.Layer):
    """Append a minibatch standard-deviation feature map (ProGAN).

    The batch is reshaped into ``group_size`` groups. The per-feature std
    across the group axis is averaged into one scalar per group member slot,
    and that scalar is broadcast to an extra ``(H, W, 1)`` channel
    concatenated onto each sample.

    Args:
        group_size: Preferred group size. If the static batch size is smaller
            than it, or not divisible by it, the largest divisor of the batch
            not exceeding it is used instead, so the reshape is always valid.
        epsilon: Added to the variance before the square root.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, group_size=4, epsilon=1e-8, **kwargs):
        super().__init__(**kwargs)
        self.group_size = group_size
        self.epsilon = epsilon

    def _resolve_group_size(self, batch):
        """Return a group size that evenly divides the static ``batch``."""
        if batch is None:
            # Symbolic batch dimension (e.g. during a symbolic build): nothing
            # to adapt to, so use the configured size as-is.
            return self.group_size
        group_size = max(1, min(self.group_size, batch))
        while batch % group_size:
            group_size -= 1
        return group_size

    def call(self, x):
        # Static Python-int shapes (not tracers) so reshape/tile stay static.
        batch = x.shape[0]
        h = x.shape[1]
        w = x.shape[2]
        c = x.shape[3]
        group_size = self._resolve_group_size(batch)
        y = keras.ops.reshape(x, [group_size, -1, h, w, c])            # (G, M, H, W, C)
        mean = keras.ops.mean(y, axis=0, keepdims=True)
        var = keras.ops.mean(keras.ops.square(y - mean), axis=0)      # (M, H, W, C)
        std = keras.ops.sqrt(var + self.epsilon)
        avg_std = keras.ops.mean(std, axis=[1, 2, 3], keepdims=True)  # (M, 1, 1, 1)
        avg_std = keras.ops.tile(avg_std, [group_size, h, w, 1])      # (G*M, H, W, 1)
        return keras.ops.concatenate([x, avg_std], axis=-1)
def toRGB():
    """Return a 1x1 equalized conv mapping features to 3 RGB channels.

    No activation follows this projection, so it is linear and uses
    ``gain=1.0``.
    """
    rgb_projection = EqualizedConv2D(filters=3, kernel_size=1, gain=1.0)
    return rgb_projection
def fromRGB(filters):
    """Return a 1x1 equalized conv + LeakyReLU(0.2) lifting RGB to ``filters`` channels.

    ``gain=2.0`` is used because a leaky-ReLU activation follows the conv.
    """
    stem = [
        EqualizedConv2D(filters, kernel_size=1, gain=2.0),
        keras.layers.LeakyReLU(0.2),
    ]
    return keras.Sequential(stem)
class GenBlock(keras.layers.Layer):
    """ProGAN-style generator block: two (conv -> LeakyReLU -> PixelNorm) stages.

    Both convs are 4x4 ``EqualizedConv2D`` layers with ``filters`` outputs.
    One LeakyReLU(0.2) and one PixelNorm instance are shared by both stages;
    neither holds trainable weights.
    """

    def __init__(self, filters):
        super().__init__()
        self.conv1 = EqualizedConv2D(filters, kernel_size=4)
        self.conv2 = EqualizedConv2D(filters, kernel_size=4)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pnorm = PixelNorm()

    def call(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pnorm(self.act(conv(x)))
        return x
class DiscBlock(keras.layers.Layer):
    """Discriminator block: conv(in) -> LReLU -> conv(out) -> LReLU -> 2x avg-pool.

    Args:
        in_filters: Output channels of the first 4x4 equalized conv.
        out_filters: Output channels of the second 4x4 equalized conv.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, in_filters, out_filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = EqualizedConv2D(in_filters, 4)
        self.conv2 = EqualizedConv2D(out_filters, 4)
        self.act = keras.layers.LeakyReLU(0.2)
        self.pool = keras.layers.AveragePooling2D(pool_size=2)

    def call(self, x):
        hidden = self.act(self.conv1(x))
        hidden = self.act(self.conv2(hidden))
        # Halve the spatial resolution.
        return self.pool(hidden)
def build_mapping_network(num_layers=8, units=512, lr_multiplier=0.01,
                          normalize_latent=True):
    """Build the StyleGAN mapping network ``z -> w`` as a ``keras.Sequential``.

    Args:
        num_layers: Number of EqualizedDense + LeakyReLU(0.2) pairs (default 8).
        units: Width of every dense layer (default 512).
        lr_multiplier: Learning-rate multiplier for every dense layer. The
            default 0.01 makes the mapping network train 100x slower.
        normalize_latent: If True, prepend a ``PixelNorm`` so the latent ``z``
            is normalised to unit RMS before the MLP, as in StyleGAN.

    Returns:
        A ``keras.Sequential`` named ``"mapping_network"``.
    """
    layers = [PixelNorm()] if normalize_latent else []
    for _ in range(num_layers):
        layers.append(EqualizedDense(units, gain=2.0, lr_multiplier=lr_multiplier))
        layers.append(keras.layers.LeakyReLU(0.2))
    return keras.Sequential(layers, name="mapping_network")
class AdaIN(keras.layers.Layer):
    """Adaptive instance normalisation modulated by a style vector ``w``.

    Each channel of ``x`` is normalised to zero mean / unit std over the two
    spatial axes. It is then scaled and shifted per sample with a style
    produced by a single equalized affine layer: ``(1 + ys) * x_norm + yb``.

    The scale is applied as ``1 + ys``, as in the official StyleGAN
    implementation. The affine layer's bias is zero-initialised, so without
    the offset the initial scale would be centred at 0 and the block would
    start out suppressing its normalised features.

    Args:
        channels: Number of channels of ``x`` (last axis).
        w_dim: Size of ``w``. It is stored on the layer but not passed to the
            affine layer, which infers its input size when built.
    """

    def __init__(self, channels, w_dim=512):
        super().__init__()
        self.channels = channels
        self.w_dim = w_dim
        # One affine map producing 2 * channels values, split into (ys, yb).
        self.style_transform = EqualizedDense(2 * channels, gain=1.0)

    def call(self, x, w):
        # x is treated as channels-last (batch, H, W, channels): stats over axes 1, 2.
        mean = keras.ops.mean(x, axis=[1, 2], keepdims=True)
        std = keras.ops.std(x, axis=[1, 2], keepdims=True) + 1e-8
        normalized_x = (x - mean) / std

        style = self.style_transform(w)                   # (batch, 2 * channels)
        ys, yb = keras.ops.split(style, 2, axis=-1)       # each (batch, channels)
        ys = keras.ops.reshape(ys, [-1, 1, 1, self.channels])
        yb = keras.ops.reshape(yb, [-1, 1, 1, self.channels])
        return (1.0 + ys) * normalized_x + yb
# --- Noise Injection ---------------------------------------------------------
class NoiseInjection(keras.layers.Layer):
    """Add per-pixel Gaussian noise scaled by a learned per-channel factor.

    One noise map of shape ``(batch, H, W, 1)`` is drawn from ``rng_key`` and
    broadcast across channels. It is scaled by the trainable ``noise_scale``
    weight of shape ``(1, 1, 1, channels)``, which is zero-initialised, so
    the layer starts as the identity.
    """

    def __init__(self, channels):
        super().__init__()
        self.B = self.add_weight(
            shape=(1, 1, 1, channels),
            initializer="zeros",
            trainable=True,
            name="noise_scale",
        )

    def call(self, x, rng_key):
        """Return ``x + B * noise`` using the explicit JAX PRNG key ``rng_key``."""
        batch = keras.ops.shape(x)[0]
        h = keras.ops.shape(x)[1]
        w = keras.ops.shape(x)[2]
        # Draw noise in x's dtype so float16/bfloat16 activations are not
        # silently promoted to float32 by the addition below.
        noise = jax.random.normal(rng_key, shape=(batch, h, w, 1), dtype=x.dtype)
        return x + self.B * noise
# --- StyleGAN Block ----------------------------------------------------------
class StyleBlock(keras.layers.Layer):
    """One synthesis stage: 4x4 conv -> noise -> LeakyReLU(0.2) -> AdaIN(w).

    Args:
        channels: Output channels of the conv (and channel count for noise/AdaIN).
        w_dim: Style-vector size passed to ``AdaIN``.
    """

    def __init__(self, channels, w_dim=512):
        super().__init__()
        self.conv = EqualizedConv2D(channels, 4)
        self.noise = NoiseInjection(channels)
        self.adain = AdaIN(channels, w_dim)
        self.act = keras.layers.LeakyReLU(0.2)

    def call(self, x, w, rng_key):
        features = self.conv(x)                     # 4x4 equalized conv
        features = self.noise(features, rng_key)
        features = self.act(features)
        return self.adain(features, w)              # style modulation from w
class StyleGAN_Generator(keras.Model):
    """Progressively grown StyleGAN generator.

    The generator proceeds in four stages:

    1. The mapping network turns ``z`` into ``w``.
    2. A learned 4x4x512 constant is tiled over the batch.
    3. The constant goes through two ``StyleBlock``s per resolution in
       ``RESOLUTIONS``, with bilinear 2x upsampling between resolutions.
    4. The result is projected to RGB at ``self.current_resolution``.

    While growing, the output is ``(1 - alpha) * old_rgb + alpha * new_rgb``.
    Here ``old_rgb`` is the previous resolution's RGB head applied to the
    upsampled previous features.
    """

    def __init__(self):
        super().__init__()
        self.mapping_network = build_mapping_network()
        self.const = self.add_weight(
            shape=(1, 4, 4, 512),
            initializer="ones",
            trainable=True,
            name="const",
        )
        # Register sub-layers as attributes (via setattr) so Keras tracks them.
        for res in RESOLUTIONS:
            filters = RES_TO_FILTERS[res]
            setattr(self, f"block_{res}_0", StyleBlock(filters))
            setattr(self, f"block_{res}_1", StyleBlock(filters))
            setattr(self, f"to_rgb_{res}", toRGB())
        self.upsample = keras.layers.UpSampling2D(size=2, interpolation="bilinear")
        self.current_resolution = 4

    def _run_stage(self, x, w, res, rng_key):
        """Apply both StyleBlocks of resolution ``res``; return (x, advanced rng_key)."""
        for idx in (0, 1):
            rng_key, subkey = jax.random.split(rng_key)
            x = getattr(self, f"block_{res}_{idx}")(x, w, subkey)
        return x, rng_key

    def call(self, z, alpha, rng_key):
        w = self.mapping_network(z)
        batch = keras.ops.shape(z)[0]
        x = keras.ops.tile(self.const, [batch, 1, 1, 1])
        x, rng_key = self._run_stage(x, w, 4, rng_key)

        target = self.current_resolution
        if target == 4:
            return getattr(self, "to_rgb_4")(x)

        for res in RESOLUTIONS[1:]:
            x_prev = x
            x, rng_key = self._run_stage(self.upsample(x), w, res, rng_key)
            if res == target:
                # Fade-in between the previous resolution's RGB output and the new one.
                old_rgb = getattr(self, f"to_rgb_{res // 2}")(self.upsample(x_prev))
                new_rgb = getattr(self, f"to_rgb_{res}")(x)
                return (1 - alpha) * old_rgb + alpha * new_rgb

        return getattr(self, f"to_rgb_{target}")(x)
class ProGAN_Discriminator(keras.Model):
    """Progressively grown ProGAN discriminator producing one score per image.

    At ``current_resolution == 4`` the image goes through ``from_rgb_4`` and
    then the shared head. At higher resolutions, two paths are blended with
    ``alpha``:

    - the new path: full-resolution ``fromRGB`` followed by that resolution's
      ``DiscBlock``;
    - the old path: the image downsampled 2x, then the previous resolution's
      ``fromRGB``.

    The blended result then passes through the remaining ``DiscBlock``s down
    to 4x4 and the shared head.
    """

    def __init__(self):
        super().__init__()
        # Register sub-layers as attributes (via setattr) so Keras tracks them.
        for res in RESOLUTIONS[1:]:
            setattr(self, f"block_{res}",
                    DiscBlock(RES_TO_FILTERS[res], RES_TO_FILTERS[res // 2]))
        for res in RESOLUTIONS:
            setattr(self, f"from_rgb_{res}", fromRGB(RES_TO_FILTERS[res]))
        self.minibatch_std = MinibatchStd()
        self.final_conv = EqualizedConv2D(512, kernel_size=3)
        self.final_dense_1 = EqualizedDense(512)
        # The score layer is a linear projection (no activation), so it uses
        # gain=1.0, consistent with toRGB and ProGAN's output layer.
        self.final_dense_2 = EqualizedDense(1, gain=1.0)
        self.flatten = keras.layers.Flatten()
        self.act = keras.layers.LeakyReLU(0.2)
        self.downsample = keras.layers.AveragePooling2D(pool_size=2)
        self.current_resolution = 4

    def _head(self, x):
        """Shared 4x4 tail: minibatch-std -> conv -> flatten -> dense -> dense(1)."""
        x = self.minibatch_std(x)
        x = self.act(self.final_conv(x))
        x = self.flatten(x)
        x = self.act(self.final_dense_1(x))
        return self.final_dense_2(x)

    def call(self, img, alpha):
        cur_res = self.current_resolution
        if cur_res == 4:
            return self._head(getattr(self, "from_rgb_4")(img))

        prev_res = cur_res // 2
        # New (fading-in) path at full resolution.
        x_new = getattr(self, f"from_rgb_{cur_res}")(img)
        x_new = getattr(self, f"block_{cur_res}")(x_new)
        # Old path: downsample the image and use the previous resolution's stem.
        x_old = getattr(self, f"from_rgb_{prev_res}")(self.downsample(img))
        x = (1 - alpha) * x_old + alpha * x_new

        # Remaining blocks, from the next-lower resolution down to 8 -> 4.
        for res in reversed(RESOLUTIONS[1:]):
            if res < cur_res:
                x = getattr(self, f"block_{res}")(x)
        return self._head(x)