SkillForge45 committed on
Commit
146ebe7
·
verified ·
1 Parent(s): 81fbc0f

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +103 -0
model.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import matplotlib.pyplot as plt
4
+
5
# Parameters
IMG_SIZE = 64     # output image side in pixels (generator upsamples 8 -> 16 -> 32 -> 64)
LATENT_DIM = 128  # Dimension of the input noise
BATCH_SIZE = 1    # samples per training step
STYLE_DIM = 3     # Number of styles: circles, squares, mixed
10
+
11
# Generator with style control
def build_generator():
    """Build the style-conditioned generator.

    Returns a Keras Model taking [noise (LATENT_DIM,), style (STYLE_DIM,)]
    and producing a (64, 64, 3) image with values in [0, 1].
    """
    layers = tf.keras.layers

    # Input layers: latent noise plus a one-hot style code.
    z = layers.Input(shape=(LATENT_DIM,))
    s = layers.Input(shape=(STYLE_DIM,))

    # Fuse noise and style, then project to an 8x8x64 feature map.
    h = layers.concatenate([z, s])
    h = layers.Dense(8 * 8 * 64, activation='relu')(h)
    h = layers.Reshape((8, 8, 64))(h)

    # Three stride-2 transposed convolutions: 8 -> 16 -> 32 -> 64.
    # The last layer emits 3 channels through a sigmoid, i.e. RGB in [0, 1].
    for filters, act in ((64, 'relu'), (32, 'relu'), (3, 'sigmoid')):
        h = layers.Conv2DTranspose(filters, (4, 4), strides=2,
                                   padding='same', activation=act)(h)

    return tf.keras.Model(inputs=[z, s], outputs=h)
28
+
29
# Discriminator (same as before)
def build_discriminator():
    """Build the unconditional discriminator.

    Maps a (IMG_SIZE, IMG_SIZE, 3) image to a single sigmoid score
    (probability the image is real).
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential()
    # Two stride-2 conv blocks downsample 64 -> 32 -> 16.
    model.add(layers.Conv2D(32, (3, 3), strides=2, padding='same',
                            input_shape=(IMG_SIZE, IMG_SIZE, 3)))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Conv2D(64, (3, 3), strides=2, padding='same'))
    model.add(layers.LeakyReLU(0.2))
    # Flatten and score.
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
40
+
41
# Create models
generator = build_generator()
discriminator = build_discriminator()

# Loss function and optimizers
# lr=2e-4 with beta_1=0.5 are the standard DCGAN Adam settings.
cross_entropy = tf.keras.losses.BinaryCrossentropy()
g_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)
d_optimizer = tf.keras.optimizers.Adam(0.0002, beta_1=0.5)

# Style encodings: one-hot vector per style class (length STYLE_DIM)
STYLES = {
    'circles': [1., 0., 0.],
    'squares': [0., 1., 0.],
    'mixed': [0., 0., 1.]
}
56
+
57
def generate_with_style(style_name):
    """Sample one image conditioned on the named style and display it.

    style_name must be a key of STYLES ('circles', 'squares', 'mixed').
    """
    # Build the conditioning one-hot and a fresh latent sample.
    style_vec = tf.constant([STYLES[style_name]], dtype=tf.float32)
    noise = tf.random.normal([1, LATENT_DIM])

    # Run the generator in inference mode and take the single image.
    image = generator([noise, style_vec], training=False)[0]

    plt.imshow(image)
    plt.title(f"Style: {style_name}")
    plt.axis('off')
    plt.show()
66
+
67
# Training function with style conditioning
@tf.function
def train_step():
    """Run one adversarial training step.

    Returns:
        (d_loss, g_loss): scalar tensors for the discriminator and
        generator losses of this step.

    NOTE(review): there is no real dataset in this script — the "real"
    batch is uniform random noise, kept as a placeholder from the
    original code.
    """
    # 1. Generate random noise and a random one-hot style per sample
    noise = tf.random.normal([BATCH_SIZE, LATENT_DIM])
    style = tf.one_hot(tf.random.uniform([BATCH_SIZE], maxval=STYLE_DIM, dtype=tf.int32), STYLE_DIM)

    # 2. Train discriminator: real images -> label 1, fakes -> label 0.
    # Bug fix: the original computed d_loss = cross_entropy(ones, fake_output),
    # which is the *generator's* objective, and never used real_output at all,
    # so the discriminator was being trained to call fakes real.
    with tf.GradientTape() as d_tape:
        generated_images = generator([noise, style], training=True)
        real_output = discriminator(tf.random.uniform((BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3)), training=True)
        fake_output = discriminator(generated_images, training=True)
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        d_loss = real_loss + fake_loss

    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    # 3. Train generator: push D to score fakes as real (label 1)
    with tf.GradientTape() as g_tape:
        generated_images = generator([noise, style], training=True)
        fake_output = discriminator(generated_images, training=True)
        g_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))

    return d_loss, g_loss
94
+
95
# Training loop (toy setup: each step trains against random-noise "reals")
for epoch in range(50):
    d_loss, g_loss = train_step()
    if epoch % 10 == 0:
        # Bug fix: convert the eager tensors to Python floats before
        # formatting — an EagerTensor does not accept the ':.3f' format
        # spec and raises TypeError in the original f-string.
        print(f"Epoch {epoch}, D Loss: {float(d_loss):.3f}, G Loss: {float(g_loss):.3f}")

# Generate samples for each style
for style_name in STYLES:
    generate_with_style(style_name)