|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
# --- Hyperparameters ---

IMG_SIZE = 64  # height and width (pixels) of generated and discriminator-input images


LATENT_DIM = 128  # length of the generator's random noise vector


BATCH_SIZE = 1  # samples per train_step; NOTE(review): 1 is unusually small for GAN training — confirm intentional


STYLE_DIM = 3  # number of style categories; also the length of the one-hot conditioning vector
|
|
|
|
|
|
|
|
def build_generator():
    """Build the style-conditioned generator.

    Takes two inputs — a (LATENT_DIM,) noise vector and a (STYLE_DIM,)
    style vector — and produces a (64, 64, 3) image with values in [0, 1]
    (final sigmoid activation). Upsampling path: 8x8 -> 16x16 -> 32x32 -> 64x64.
    """
    z = tf.keras.layers.Input(shape=(LATENT_DIM,))
    s = tf.keras.layers.Input(shape=(STYLE_DIM,))

    # Condition the generator by concatenating style onto the noise vector.
    h = tf.keras.layers.Concatenate()([z, s])
    h = tf.keras.layers.Dense(8 * 8 * 64, activation='relu')(h)
    h = tf.keras.layers.Reshape((8, 8, 64))(h)

    # Three stride-2 transposed convolutions double the spatial size each time.
    h = tf.keras.layers.Conv2DTranspose(
        64, (4, 4), strides=2, padding='same', activation='relu')(h)
    h = tf.keras.layers.Conv2DTranspose(
        32, (4, 4), strides=2, padding='same', activation='relu')(h)
    h = tf.keras.layers.Conv2DTranspose(
        3, (4, 4), strides=2, padding='same', activation='sigmoid')(h)

    return tf.keras.Model(inputs=[z, s], outputs=h)
|
|
|
|
|
|
|
|
def build_discriminator():
    """Build the real/fake classifier.

    Maps a (IMG_SIZE, IMG_SIZE, 3) image to a single sigmoid probability
    via two stride-2 Conv2D + LeakyReLU stages and a dense head.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(
        32, (3, 3), strides=2, padding='same',
        input_shape=(IMG_SIZE, IMG_SIZE, 3)))
    model.add(tf.keras.layers.LeakyReLU(0.2))
    model.add(tf.keras.layers.Conv2D(64, (3, 3), strides=2, padding='same'))
    model.add(tf.keras.layers.LeakyReLU(0.2))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    return model
|
|
|
|
|
|
|
|
# Instantiate the two adversarial networks once at module load.
generator = build_generator()


discriminator = build_discriminator()
|
|
|
|
|
|
|
|
# Shared binary cross-entropy loss; expects probabilities (not logits),
# matching the discriminator's sigmoid output.
cross_entropy = tf.keras.losses.BinaryCrossentropy()


# Separate Adam optimizers per network; lr=2e-4, beta_1=0.5 are the common DCGAN settings.
g_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)


d_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
|
|
|
|
|
|
|
|
# Human-readable style names mapped to their one-hot conditioning vectors
# (length STYLE_DIM), fed as the generator's second input.
STYLES = {


    'circles': [1., 0., 0.],


    'squares': [0., 1., 0.],


    'mixed': [0., 0., 1.]


}
|
|
|
|
|
def generate_with_style(style_name):
    """Sample one image from the generator for a named style and display it.

    ``style_name`` must be a key of ``STYLES``; raises KeyError otherwise.
    Runs the generator in inference mode (training=False) on fresh noise
    and shows the result with matplotlib.
    """
    one_hot = STYLES[style_name]
    latent = tf.random.normal([1, LATENT_DIM])
    condition = tf.constant([one_hot], dtype=tf.float32)

    # [0] drops the batch dimension for plotting.
    image = generator([latent, condition], training=False)[0]

    plt.imshow(image)
    plt.title(f"Style: {style_name}")
    plt.axis('off')
    plt.show()
|
|
|
|
|
|
|
|
@tf.function
def train_step():
    """Run one adversarial update of the discriminator then the generator.

    Samples a batch of noise and random one-hot style vectors, updates the
    discriminator against real and generated images, then updates the
    generator to fool the (just-updated) discriminator.

    Returns:
        (d_loss, g_loss): scalar tensors for the discriminator and
        generator losses of this step.
    """
    noise = tf.random.normal([BATCH_SIZE, LATENT_DIM])
    style = tf.one_hot(
        tf.random.uniform([BATCH_SIZE], maxval=STYLE_DIM, dtype=tf.int32),
        STYLE_DIM)

    # --- Discriminator update ---
    with tf.GradientTape() as d_tape:
        generated_images = generator([noise, style], training=True)

        # NOTE(review): "real" images here are uniform random noise — a
        # placeholder. Swap in a batch from an actual dataset for
        # meaningful training.
        real_output = discriminator(
            tf.random.uniform((BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3)),
            training=True)
        fake_output = discriminator(generated_images, training=True)

        # BUG FIX: the original computed
        #   d_loss = cross_entropy(ones_like(fake_output), fake_output)
        # which pushes the discriminator to label fakes as real — that is
        # the GENERATOR's objective — and never used real_output at all.
        # The discriminator loss must reward 1 on real and 0 on fake.
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        d_loss = real_loss + fake_loss

    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    # --- Generator update ---
    with tf.GradientTape() as g_tape:
        generated_images = generator([noise, style], training=True)
        fake_output = discriminator(generated_images, training=True)
        # Generator wants the discriminator to output 1 ("real") on fakes.
        g_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))

    return d_loss, g_loss
|
|
|
|
|
|
|
|
# Main training loop: 50 single-step "epochs", logging every 10th.
for epoch in range(50):
    d_loss, g_loss = train_step()
    if not epoch % 10:
        print(f"Epoch {epoch}, D Loss: {d_loss:.3f}, G Loss: {g_loss:.3f}")
|
|
|
|
|
|
|
|
# Show one generated sample for every configured style.
for style_name in STYLES:
    generate_with_style(style_name)