Image-to-Image
Adapters
chemistry
art
K1Z3M1112 commited on
Commit
69d5ab4
·
verified ·
1 Parent(s): 1c54848

Upload 6 files

Browse files
Files changed (6) hide show
  1. config.py +46 -0
  2. dataset.py +51 -0
  3. inference.py +126 -0
  4. model.py +195 -0
  5. train.py +112 -0
  6. utils.py +87 -0
config.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ # Define the directories
9
+ LINE_ART_DIR = "train_images/line_arts"
10
+ COLORED_DIR = "train_images/colored"
11
+
12
+ # Ensure directories exist
13
+ os.makedirs(LINE_ART_DIR, exist_ok=True)
14
+ os.makedirs(COLORED_DIR, exist_ok=True)
15
+
16
+ # Function to download and process images
17
+ def download_and_process_images(manga_title, url):
18
+ response = requests.get(url)
19
+ soup = BeautifulSoup(response.content, 'html.parser')
20
+
21
+ # Find image tags (this will depend on the structure of the webpage)
22
+ image_tags = soup.find_all('img')
23
+
24
+ for img in image_tags:
25
+ img_url = img['src']
26
+ img_data = requests.get(img_url).content
27
+ img_array = np.frombuffer(img_data, np.uint8)
28
+ img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
29
+
30
+ # Check if the image is colored or line art
31
+ if len(img.shape) == 3 and img.shape[2] == 3:
32
+ # Colored image
33
+ colored_path = os.path.join(COLORED_DIR, f"{manga_title}_colored.png")
34
+ cv2.imwrite(colored_path, img)
35
+ else:
36
+ # Line art image
37
+ line_art_path = os.path.join(LINE_ART_DIR, f"{manga_title}_line_art.png")
38
+ cv2.imwrite(line_art_path, img)
39
+
40
+ # Example usage
41
+ manga_title = "example_manga"
42
+ nhentai_url = "https://nhentai.net/g/your_manga_id/"
43
+ hitomi_url = "https://hitomi.la/galleries/your_manga_id.html"
44
+
45
+ download_and_process_images(manga_title, nhentai_url)
46
+ download_and_process_images(manga_title, hitomi_url)
dataset.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import os
3
+ from config import config
4
+
5
+ class MangaDataset:
6
+ def __init__(self):
7
+ self.line_art_dir = config.LINE_ART_DIR
8
+ self.colored_dir = config.COLORED_DIR
9
+
10
+ def load_data(self):
11
+ """โหลดและเตรียมข้อมูล"""
12
+ line_art_paths = sorted([os.path.join(self.line_art_dir, f)
13
+ for f in os.listdir(self.line_art_dir)
14
+ if f.endswith(('.png', '.jpg'))])
15
+
16
+ colored_paths = sorted([os.path.join(self.colored_dir, f)
17
+ for f in os.listdir(self.colored_dir)
18
+ if f.endswith(('.png', '.jpg'))])
19
+
20
+ # สร้าง dataset
21
+ dataset = tf.data.Dataset.from_tensor_slices((line_art_paths, colored_paths))
22
+ dataset = dataset.map(self.process_paths, num_parallel_calls=tf.data.AUTOTUNE)
23
+ dataset = dataset.batch(config.BATCH_SIZE)
24
+ dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
25
+
26
+ return dataset
27
+
28
+ def process_paths(self, line_art_path, colored_path):
29
+ """ประมวลผลภาพคู่"""
30
+ # โหลดภาพเส้น
31
+ line_art = self.load_image(line_art_path, is_line_art=True)
32
+
33
+ # โหลดภาพสี
34
+ colored = self.load_image(colored_path, is_line_art=False)
35
+
36
+ return line_art, colored
37
+
38
+ def load_image(self, image_path, is_line_art=False):
39
+ """โหลดภาพเดียว"""
40
+ image = tf.io.read_file(image_path)
41
+ image = tf.image.decode_image(image, channels=3)
42
+ image = tf.image.resize(image, [config.IMAGE_SIZE[0], config.IMAGE_SIZE[1]])
43
+ image = tf.cast(image, tf.float32) / 255.0
44
+
45
+ if is_line_art:
46
+ # แปลงเป็น grayscale และทำให้เส้นคมชัด
47
+ image = tf.image.rgb_to_grayscale(image)
48
+ # ทำให้เส้นดำสนิท พื้นขาว
49
+ image = tf.where(image < 0.5, 0.0, 1.0)
50
+
51
+ return image
inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import argparse
6
+ import os
7
+
8
+ class MangaColorizerInference:
9
+ def __init__(self, model_path):
10
+ """โหลดโมเดลที่ฝึกไว้แล้ว"""
11
+ try:
12
+ if not os.path.exists(model_path):
13
+ raise FileNotFoundError(f"ไม่พบไฟล์โมเดล: {model_path}")
14
+
15
+ self.model = tf.keras.models.load_model(model_path)
16
+ print(f"✅ โหลดโมเดลสำเร็จจาก: {model_path}")
17
+
18
+ except Exception as e:
19
+ print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}")
20
+ raise
21
+
22
+ def load_and_preprocess(self, image_path):
23
+ """โหลดและเตรียมภาพสำหรับโมเดล"""
24
+ try:
25
+ # โหลดภาพ
26
+ image = cv2.imread(image_path)
27
+ if image is None:
28
+ raise ValueError(f"ไม่สามารถโหลดภาพจาก: {image_path}")
29
+
30
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
31
+
32
+ # บันทึกขนาดเดิม
33
+ original_size = image.shape[:2]
34
+
35
+ # ปรับขนาด
36
+ image_resized = cv2.resize(image, (256, 256))
37
+
38
+ # แปลงเป็น grayscale
39
+ gray = cv2.cvtColor(image_resized, cv2.COLOR_RGB2GRAY)
40
+
41
+ # ทำให้เส้นคมชัด
42
+ _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
43
+ binary = cv2.bitwise_not(binary)
44
+
45
+ # Normalize
46
+ binary = binary.astype(np.float32) / 255.0
47
+ binary = np.expand_dims(binary, axis=-1)
48
+
49
+ return binary, original_size, image_resized
50
+
51
+ except Exception as e:
52
+ print(f"❌ ข้อผิดพลาดในการประมวลผลภาพ: {e}")
53
+ raise
54
+
55
+ def colorize(self, image_path, output_path, save_comparison=False):
56
+ """ลงสีภาพเส้น"""
57
+ try:
58
+ # โหลดและประมวลผลภาพ
59
+ image, original_size, original_resized = self.load_and_preprocess(image_path)
60
+
61
+ # ทำนายสี
62
+ print("🎨 กำลังลงสี...")
63
+ prediction = self.model.predict(image[np.newaxis, ...], verbose=0)[0]
64
+
65
+ # แปลงกลับเป็นภาพและปรับขนาดกลับ
66
+ colored_image = (prediction * 255).astype(np.uint8)
67
+ colored_image_original_size = cv2.resize(colored_image, (original_size[1], original_size[0]))
68
+
69
+ # บันทึกผลลัพธ์
70
+ Image.fromarray(colored_image_original_size).save(output_path)
71
+ print(f"💾 บันทึกภาพที่: {output_path}")
72
+
73
+ # บันทึกภาพเปรียบเทียบ (optional)
74
+ if save_comparison:
75
+ comparison = self.create_comparison(original_resized, colored_image)
76
+ comparison_path = output_path.replace('.png', '_comparison.png')
77
+ comparison.save(comparison_path)
78
+ print(f"📊 บันทึกภาพเปรียบเทียบที่: {comparison_path}")
79
+
80
+ return colored_image_original_size
81
+
82
+ except Exception as e:
83
+ print(f"❌ ข้อผิดพลาดในการลงสี: {e}")
84
+ return None
85
+
86
+ def create_comparison(self, original, colored):
87
+ """สร้างภาพเปรียบเทียบก่อน-หลัง"""
88
+ # แปลงภาพต้นฉบับเป็น RGB สำหรับแสดง
89
+ if len(original.shape) == 3 and original.shape[2] == 3:
90
+ original_rgb = original
91
+ else:
92
+ original_rgb = cv2.cvtColor(original, cv2.COLOR_GRAY2RGB)
93
+
94
+ # รวมภาพ
95
+ comparison = np.hstack([original_rgb, colored])
96
+ return Image.fromarray(comparison)
97
+
98
+ def main():
99
+ parser = argparse.ArgumentParser(description='AI Manga Colorizer - Inference')
100
+ parser.add_argument('--input', type=str, required=True, help='Path to input line art')
101
+ parser.add_argument('--output', type=str, required=True, help='Path to save colored image')
102
+ parser.add_argument('--model', type=str, default='output/manga_colorizer.h5', help='Path to trained model')
103
+ parser.add_argument('--compare', action='store_true', help='Save comparison image')
104
+
105
+ args = parser.parse_args()
106
+
107
+ # ตรวจสอบไฟล์ input
108
+ if not os.path.exists(args.input):
109
+ print(f"❌ ไม่พบไฟล์ input: {args.input}")
110
+ return
111
+
112
+ # ใช้งานโมเดล
113
+ try:
114
+ colorizer = MangaColorizerInference(args.model)
115
+ result = colorizer.colorize(args.input, args.output, save_comparison=args.compare)
116
+
117
+ if result is not None:
118
+ print("✅ เสร็จสิ้น!")
119
+ else:
120
+ print("❌ การลงสีล้มเหลว")
121
+
122
+ except Exception as e:
123
+ print(f"❌ การทำงานล้มเหลว: {e}")
124
+
125
+ if __name__ == "__main__":
126
+ main()
model.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, Model
3
+ from config import config
4
+
5
+ def downsample(filters, size, apply_batchnorm=True):
6
+ """Downsampling block"""
7
+ initializer = tf.random_normal_initializer(0., 0.02)
8
+
9
+ result = tf.keras.Sequential()
10
+ result.add(
11
+ layers.Conv2D(filters, size, strides=2, padding='same',
12
+ kernel_initializer=initializer, use_bias=False)
13
+ )
14
+
15
+ if apply_batchnorm:
16
+ result.add(layers.BatchNormalization())
17
+
18
+ result.add(layers.LeakyReLU())
19
+
20
+ return result
21
+
22
+ def upsample(filters, size, apply_dropout=False):
23
+ """Upsampling block"""
24
+ initializer = tf.random_normal_initializer(0., 0.02)
25
+
26
+ result = tf.keras.Sequential()
27
+ result.add(
28
+ layers.Conv2DTranspose(filters, size, strides=2,
29
+ padding='same',
30
+ kernel_initializer=initializer,
31
+ use_bias=False)
32
+ )
33
+
34
+ result.add(layers.BatchNormalization())
35
+
36
+ if apply_dropout:
37
+ result.add(layers.Dropout(0.5))
38
+
39
+ result.add(layers.ReLU())
40
+
41
+ return result
42
+
43
+ def build_generator():
44
+ """สร้าง Generator แบบ U-Net"""
45
+ inputs = layers.Input(shape=[256, 256, 1])
46
+
47
+ # Encoder
48
+ down_stack = [
49
+ downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
50
+ downsample(128, 4), # (bs, 64, 64, 128)
51
+ downsample(256, 4), # (bs, 32, 32, 256)
52
+ downsample(512, 4), # (bs, 16, 16, 512)
53
+ downsample(512, 4), # (bs, 8, 8, 512)
54
+ downsample(512, 4), # (bs, 4, 4, 512)
55
+ downsample(512, 4), # (bs, 2, 2, 512)
56
+ downsample(512, 4), # (bs, 1, 1, 512)
57
+ ]
58
+
59
+ # Decoder
60
+ up_stack = [
61
+ upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
62
+ upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
63
+ upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
64
+ upsample(512, 4), # (bs, 16, 16, 1024)
65
+ upsample(256, 4), # (bs, 32, 32, 512)
66
+ upsample(128, 4), # (bs, 64, 64, 256)
67
+ upsample(64, 4), # (bs, 128, 128, 128)
68
+ ]
69
+
70
+ initializer = tf.random_normal_initializer(0., 0.02)
71
+ last = layers.Conv2DTranspose(
72
+ config.OUTPUT_CHANNELS, 4, strides=2,
73
+ padding='same', kernel_initializer=initializer,
74
+ activation='sigmoid'
75
+ ) # (bs, 256, 256, 3)
76
+
77
+ x = inputs
78
+
79
+ # Downsampling และเก็บ skip connections
80
+ skips = []
81
+ for down in down_stack:
82
+ x = down(x)
83
+ skips.append(x)
84
+
85
+ skips = reversed(skips[:-1])
86
+
87
+ # Upsampling และเชื่อม skip connections
88
+ for up, skip in zip(up_stack, skips):
89
+ x = up(x)
90
+ x = layers.Concatenate()([x, skip])
91
+
92
+ x = last(x)
93
+
94
+ return Model(inputs=inputs, outputs=x)
95
+
96
+ def build_discriminator():
97
+ """สร้าง Discriminator"""
98
+ initializer = tf.random_normal_initializer(0., 0.02)
99
+
100
+ inp = layers.Input(shape=[256, 256, 1], name='input_image')
101
+ tar = layers.Input(shape=[256, 256, 3], name='target_image')
102
+
103
+ x = layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
104
+
105
+ down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
106
+ down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
107
+ down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)
108
+
109
+ zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
110
+ conv = layers.Conv2D(
111
+ 512, 4, strides=1,
112
+ kernel_initializer=initializer,
113
+ use_bias=False
114
+ )(zero_pad1) # (bs, 31, 31, 512)
115
+
116
+ batchnorm1 = layers.BatchNormalization()(conv)
117
+ leaky_relu = layers.LeakyReLU()(batchnorm1)
118
+ zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
119
+
120
+ last = layers.Conv2D(
121
+ 1, 4, strides=1,
122
+ kernel_initializer=initializer
123
+ )(zero_pad2) # (bs, 30, 30, 1)
124
+
125
+ return Model(inputs=[inp, tar], outputs=last)
126
+
127
+ class MangaColorizer(Model):
128
+ """คลาสหลักสำหรับการลงสีมังงะ"""
129
+ def __init__(self):
130
+ super().__init__()
131
+ self.generator = build_generator()
132
+ self.discriminator = build_discriminator()
133
+
134
+ def compile(self, g_optimizer, d_optimizer, loss_fn):
135
+ super().compile()
136
+ self.g_optimizer = g_optimizer
137
+ self.d_optimizer = d_optimizer
138
+ self.loss_fn = loss_fn
139
+
140
+ def train_step(self, data):
141
+ input_image, target_image = data
142
+
143
+ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
144
+ # Generator สร้างภาพ
145
+ gen_output = self.generator(input_image, training=True)
146
+
147
+ # Discriminator ตรวจสอบ
148
+ disc_real_output = self.discriminator([input_image, target_image], training=True)
149
+ disc_generated_output = self.discriminator([input_image, gen_output], training=True)
150
+
151
+ # คำนวณ loss
152
+ gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
153
+ disc_generated_output, gen_output, target_image
154
+ )
155
+ disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)
156
+
157
+ # อัพเดท gradient
158
+ generator_gradients = gen_tape.gradient(
159
+ gen_total_loss, self.generator.trainable_variables
160
+ )
161
+ discriminator_gradients = disc_tape.gradient(
162
+ disc_loss, self.discriminator.trainable_variables
163
+ )
164
+
165
+ self.g_optimizer.apply_gradients(
166
+ zip(generator_gradients, self.generator.trainable_variables)
167
+ )
168
+ self.d_optimizer.apply_gradients(
169
+ zip(discriminator_gradients, self.discriminator.trainable_variables)
170
+ )
171
+
172
+ return {
173
+ "gen_total_loss": gen_total_loss,
174
+ "gen_gan_loss": gen_gan_loss,
175
+ "gen_l1_loss": gen_l1_loss,
176
+ "disc_loss": disc_loss
177
+ }
178
+
179
+ def generator_loss(self, disc_generated_output, gen_output, target):
180
+ gan_loss = self.loss_fn(tf.ones_like(disc_generated_output), disc_generated_output)
181
+
182
+ # L1 loss - ทำให้สีใกล้เคียงกับภาพจริง
183
+ l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
184
+
185
+ total_gen_loss = gan_loss + (100 * l1_loss) # L1 weight = 100
186
+
187
+ return total_gen_loss, gan_loss, l1_loss
188
+
189
+ def discriminator_loss(self, disc_real_output, disc_generated_output):
190
+ real_loss = self.loss_fn(tf.ones_like(disc_real_output), disc_real_output)
191
+ generated_loss = self.loss_fn(tf.zeros_like(disc_generated_output), disc_generated_output)
192
+
193
+ total_disc_loss = real_loss + generated_loss
194
+
195
+ return total_disc_loss
train.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from model import MangaColorizer
3
+ from dataset import MangaDataset
4
+ from utils import save_comparison, prepare_directories, check_dataset_size
5
+ from config import config
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+
9
+ class TrainingMonitor(tf.keras.callbacks.Callback):
10
+ """Callback สำหรับตรวจสอบการฝึก"""
11
+ def __init__(self, dataset):
12
+ super().__init__()
13
+ self.dataset = dataset
14
+
15
+ def on_epoch_end(self, epoch, logs=None):
16
+ # ใช้ภาพตัวอย่างจาก batch แรก
17
+ for input_image, target_image in self.dataset.take(1):
18
+ prediction = self.model.generator(input_image, training=False)
19
+ save_comparison(input_image, prediction, epoch, 0)
20
+
21
+ # พิมพ์ loss ทุก epoch
22
+ if logs:
23
+ print(f"📊 Epoch {epoch+1}: "
24
+ f"Gen Loss: {logs.get('gen_total_loss', 0):.4f}, "
25
+ f"Disc Loss: {logs.get('disc_loss', 0):.4f}")
26
+
27
+ def main():
28
+ # เตรียม directory
29
+ prepare_directories()
30
+
31
+ # ตรวจสอบ dataset
32
+ if not check_dataset_size(config.LINE_ART_DIR, config.COLORED_DIR):
33
+ print("❌ กรุณาตรวจสอบ dataset ก่อนเริ่มฝึก")
34
+ return
35
+
36
+ # โหลดข้อมูล
37
+ print("🔄 กำลังโหลดข้อมูล...")
38
+ dataset = MangaDataset().load_data()
39
+
40
+ # ตรวจสอบว่ามีข้อมูลหรือไม่
41
+ try:
42
+ sample_batch = next(iter(dataset))
43
+ print(f"✅ โหลดข้อมูลสำเร็จ: Batch size {sample_batch[0].shape}")
44
+ except StopIteration:
45
+ print("❌ ไม่มีข้อมูลใน dataset")
46
+ return
47
+
48
+ # สร้างโมเดล
49
+ print("🔄 กำลังสร้างโมเดล...")
50
+ colorizer = MangaColorizer()
51
+
52
+ # Compile โมเดล
53
+ generator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
54
+ discriminator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
55
+ loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
56
+
57
+ colorizer.compile(
58
+ g_optimizer=generator_optimizer,
59
+ d_optimizer=discriminator_optimizer,
60
+ loss_fn=loss_fn
61
+ )
62
+
63
+ print(f"✅ สร้างโมเดลสำเร็จ")
64
+ print(f"📈 เริ่มฝึก {config.EPOCHS} epochs...")
65
+
66
+ # ฝึกโมเดล
67
+ history = colorizer.fit(
68
+ dataset,
69
+ epochs=config.EPOCHS,
70
+ callbacks=[TrainingMonitor(dataset)],
71
+ verbose=1
72
+ )
73
+
74
+ # บันทึกโมเดล
75
+ os.makedirs('output', exist_ok=True)
76
+ colorizer.generator.save('output/manga_colorizer.h5')
77
+ print("✅ บันทึกโมเดลเรียบร้อย: output/manga_colorizer.h5")
78
+
79
+ # พล็อตกราฟ loss
80
+ try:
81
+ plt.figure(figsize=(12, 4))
82
+
83
+ plt.subplot(1, 2, 1)
84
+ if 'gen_total_loss' in history.history:
85
+ plt.plot(history.history['gen_total_loss'], label='Generator Total Loss')
86
+ if 'gen_gan_loss' in history.history:
87
+ plt.plot(history.history['gen_gan_loss'], label='Generator GAN Loss', linestyle='--')
88
+ if 'gen_l1_loss' in history.history:
89
+ plt.plot(history.history['gen_l1_loss'], label='Generator L1 Loss', linestyle=':')
90
+ plt.title('Generator Loss')
91
+ plt.xlabel('Epoch')
92
+ plt.ylabel('Loss')
93
+ plt.legend()
94
+
95
+ plt.subplot(1, 2, 2)
96
+ if 'disc_loss' in history.history:
97
+ plt.plot(history.history['disc_loss'], label='Discriminator Loss', color='red')
98
+ plt.title('Discriminator Loss')
99
+ plt.xlabel('Epoch')
100
+ plt.ylabel('Loss')
101
+ plt.legend()
102
+
103
+ plt.tight_layout()
104
+ plt.savefig('output/training_loss.png', dpi=300, bbox_inches='tight')
105
+ plt.close()
106
+ print("✅ บันทึกกราฟ training loss: output/training_loss.png")
107
+
108
+ except Exception as e:
109
+ print(f"❌ ไม่สามารถบันทึกกราฟ loss: {e}")
110
+
111
+ if __name__ == "__main__":
112
+ main()
utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+ import os
6
+
7
+ def load_image(image_path, is_line_art=False):
8
+ """โหลดภาพและประมวลผลเบื้องต้น"""
9
+ image = tf.io.read_file(image_path)
10
+ image = tf.image.decode_image(image, channels=3)
11
+ image = tf.image.resize(image, [256, 256])
12
+ image = tf.cast(image, tf.float32) / 255.0
13
+
14
+ if is_line_art:
15
+ # แปลงเป็น grayscale และทำให้เส้นคมชัด
16
+ image = tf.image.rgb_to_grayscale(image)
17
+ # ทำให้เส้นดำสนิท พื้นขาว
18
+ image = tf.where(image < 0.5, 0.0, 1.0)
19
+
20
+ return image
21
+
22
+ def extract_line_art_from_colored(colored_image):
23
+ """สกัดเส้นจากภาพสี (ใช้สร้าง training data)"""
24
+ # แปลงเป็น grayscale
25
+ gray = cv2.cvtColor(colored_image, cv2.COLOR_RGB2GRAY)
26
+
27
+ # ใช้ adaptive threshold เพื่อให้เส้นคมชัด
28
+ line_art = cv2.adaptiveThreshold(
29
+ gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
30
+ cv2.THRESH_BINARY_INV, 11, 2
31
+ )
32
+
33
+ # ลบ noise
34
+ kernel = np.ones((2, 2), np.uint8)
35
+ line_art = cv2.morphologyEx(line_art, cv2.MORPH_OPEN, kernel)
36
+
37
+ return line_art
38
+
39
+ def save_comparison(image, prediction, epoch, step):
40
+ """บันทึกภาพเปรียบเทียบ - แก้ไขเวอร์ชัน"""
41
+ try:
42
+ # แปลง TensorFlow tensor เป็น numpy array
43
+ input_image = image[0].numpy() # (256, 256, 1)
44
+ pred_image = prediction[0].numpy() # (256, 256, 3)
45
+
46
+ # แปลงภาพ input จาก grayscale เป็น RGB
47
+ if input_image.shape[-1] == 1:
48
+ input_image_rgb = np.repeat(input_image, 3, axis=-1)
49
+ else:
50
+ input_image_rgb = input_image
51
+
52
+ # รวมภาพ input และ prediction ข้างกัน
53
+ combined = np.concatenate([input_image_rgb, pred_image], axis=1)
54
+
55
+ # คลิปค่าและแปลงเป็น uint8
56
+ combined = np.clip(combined, 0, 1)
57
+ combined = (combined * 255).astype(np.uint8)
58
+
59
+ # บันทึกภาพ
60
+ comparison = Image.fromarray(combined)
61
+ os.makedirs('output', exist_ok=True)
62
+ comparison.save(f'output/epoch_{epoch:03d}_step_{step:03d}.png')
63
+ print(f"💾 บันทึกภาพเปรียบเทียบ: output/epoch_{epoch:03d}_step_{step:03d}.png")
64
+
65
+ except Exception as e:
66
+ print(f"❌ ข้อผิดพลาดในการบันทึกภาพเปรียบเทียบ: {e}")
67
+
68
+ def prepare_directories():
69
+ """สร้าง directory ที่จำเป็น"""
70
+ os.makedirs('train_images/line_arts', exist_ok=True)
71
+ os.makedirs('train_images/colored', exist_ok=True)
72
+ os.makedirs('output', exist_ok=True)
73
+ print("✅ สร้าง directory เรียบร้อย")
74
+
75
+ # ฟังก์ชันเพิ่มเติมสำหรับการตรวจสอบข้อมูล
76
+ def check_dataset_size(line_art_dir, colored_dir):
77
+ """ตรวจสอบจำนวนไฟล์ใน dataset"""
78
+ line_art_files = [f for f in os.listdir(line_art_dir) if f.endswith(('.png', '.jpg'))]
79
+ colored_files = [f for f in os.listdir(colored_dir) if f.endswith(('.png', '.jpg'))]
80
+
81
+ print(f"📁 ภาพเส้น: {len(line_art_files)} ไฟล์")
82
+ print(f"🎨 ภาพสี: {len(colored_files)} ไฟล์")
83
+
84
+ if len(line_art_files) != len(colored_files):
85
+ print("⚠️ จำนวนภาพเส้นและภาพสีไม่เท่ากัน!")
86
+ return False
87
+ return True