Upload 6 files
Browse files- config.py +46 -0
- dataset.py +51 -0
- inference.py +126 -0
- model.py +195 -0
- train.py +112 -0
- utils.py +87 -0
config.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from bs4 import BeautifulSoup
|
| 3 |
+
import os
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
# Directories holding the paired training data.
LINE_ART_DIR = "train_images/line_arts"
COLORED_DIR = "train_images/colored"

# Ensure directories exist
os.makedirs(LINE_ART_DIR, exist_ok=True)
os.makedirs(COLORED_DIR, exist_ok=True)


class _Config:
    """Central hyper-parameter bundle shared by the training pipeline.

    CONSISTENCY FIX: dataset.py, model.py and train.py all do
    ``from config import config`` and read these attributes, but this module
    never defined a ``config`` object, so importing any of those files failed.
    The training values below are conventional pix2pix-style defaults --
    tune them for your dataset.
    """
    LINE_ART_DIR = "train_images/line_arts"   # must match LINE_ART_DIR above
    COLORED_DIR = "train_images/colored"      # must match COLORED_DIR above
    IMAGE_SIZE = (256, 256)   # model input/output resolution (see model.py)
    BATCH_SIZE = 8
    OUTPUT_CHANNELS = 3       # RGB output
    LEARNING_RATE = 2e-4      # Adam learning rate used by train.py (beta_1=0.5)
    EPOCHS = 100


# Shared singleton read by the rest of the project.
config = _Config()
|
| 15 |
+
|
| 16 |
+
# Function to download and process images
|
| 17 |
+
def download_and_process_images(manga_title, url):
    """Scrape every <img> tag from *url* and save each image for training.

    Near-grayscale / single-channel images go to LINE_ART_DIR, the rest to
    COLORED_DIR. Filenames carry an index so multiple images from one page
    do not overwrite each other.

    Args:
        manga_title: prefix used in the saved filenames.
        url: gallery page to scrape (structure-dependent; see note below).
    """
    from urllib.parse import urljoin  # local import: only needed here

    response = requests.get(url)
    soup = BeautifulSoup(response.content, 'html.parser')

    # Find image tags (this depends on the structure of the webpage)
    image_tags = soup.find_all('img')

    for index, tag in enumerate(image_tags):
        # BUG FIX: <img> tags may lack src, and src is often relative;
        # skip the former and resolve the latter against the page URL.
        src = tag.get('src')
        if not src:
            continue
        img_url = urljoin(url, src)

        img_data = requests.get(img_url).content
        img_array = np.frombuffer(img_data, np.uint8)
        # BUG FIX: IMREAD_COLOR always produced 3 channels, so the
        # line-art branch below was unreachable; IMREAD_UNCHANGED keeps
        # the file's real channel count.
        img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
        if img is None:
            # Not a decodable raster image (tracking pixel, SVG, broken download).
            continue

        # Classify: 3+ channel data is treated as colored, otherwise line art.
        if img.ndim == 3 and img.shape[2] >= 3:
            # BUG FIX: index the filename so images don't overwrite each other.
            colored_path = os.path.join(COLORED_DIR, f"{manga_title}_{index:04d}_colored.png")
            cv2.imwrite(colored_path, img)
        else:
            line_art_path = os.path.join(LINE_ART_DIR, f"{manga_title}_{index:04d}_line_art.png")
            cv2.imwrite(line_art_path, img)
|
| 39 |
+
|
| 40 |
+
# BUG FIX: this example previously ran at module import time, so any
# `import config` elsewhere in the project triggered live network scraping
# (and crashed importers on failure). Guarding it keeps the module importable.
if __name__ == "__main__":
    # Example usage -- replace the placeholder gallery IDs with real ones.
    manga_title = "example_manga"
    nhentai_url = "https://nhentai.net/g/your_manga_id/"
    hitomi_url = "https://hitomi.la/galleries/your_manga_id.html"

    download_and_process_images(manga_title, nhentai_url)
    download_and_process_images(manga_title, hitomi_url)
|
dataset.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import os
|
| 3 |
+
from config import config
|
| 4 |
+
|
| 5 |
+
class MangaDataset:
    """tf.data pipeline yielding paired (line art, colored) image batches."""

    def __init__(self):
        # Directories are taken from the shared project config.
        self.line_art_dir = config.LINE_ART_DIR
        self.colored_dir = config.COLORED_DIR

    def load_data(self):
        """Build and return the batched, prefetched training dataset.

        NOTE(review): pairing relies on both directories containing the same
        number of files that sort into matching order -- verify the naming
        scheme actually guarantees aligned pairs.
        """
        line_art_paths = sorted(
            os.path.join(self.line_art_dir, f)
            for f in os.listdir(self.line_art_dir)
            if f.endswith(('.png', '.jpg'))
        )
        colored_paths = sorted(
            os.path.join(self.colored_dir, f)
            for f in os.listdir(self.colored_dir)
            if f.endswith(('.png', '.jpg'))
        )

        # Build the dataset: decode pairs in parallel, then batch + prefetch.
        dataset = tf.data.Dataset.from_tensor_slices((line_art_paths, colored_paths))
        dataset = dataset.map(self.process_paths, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(config.BATCH_SIZE)
        dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

        return dataset

    def process_paths(self, line_art_path, colored_path):
        """Load one (line art, colored) training pair."""
        line_art = self.load_image(line_art_path, is_line_art=True)
        colored = self.load_image(colored_path, is_line_art=False)
        return line_art, colored

    def load_image(self, image_path, is_line_art=False):
        """Load, resize and normalize a single image to [0, 1].

        BUG FIX: ``decode_image`` without ``expand_animations=False`` can
        return a 4-D tensor (animated GIF) with no static shape, which breaks
        ``tf.image.resize`` inside ``dataset.map``; force a 3-D result and pin
        the static shape for graph mode.
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_image(image, channels=3, expand_animations=False)
        image.set_shape([None, None, 3])
        image = tf.image.resize(image, [config.IMAGE_SIZE[0], config.IMAGE_SIZE[1]])
        image = tf.cast(image, tf.float32) / 255.0

        if is_line_art:
            # Grayscale, then hard-threshold: solid black lines on white.
            image = tf.image.rgb_to_grayscale(image)
            image = tf.where(image < 0.5, 0.0, 1.0)

        return image
|
inference.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class MangaColorizerInference:
    """Runs a trained manga-colorization generator on single line-art images."""

    def __init__(self, model_path):
        """Load the trained Keras model from *model_path*.

        Raises FileNotFoundError when the file is missing; any Keras loading
        error is logged and re-raised.
        """
        try:
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"ไม่พบไฟล์โมเดล: {model_path}")

            self.model = tf.keras.models.load_model(model_path)
            print(f"✅ โหลดโมเดลสำเร็จจาก: {model_path}")

        except Exception as e:
            print(f"❌ เกิดข้อผิดพลาดในการโหลดโมเดล: {e}")
            raise

    def load_and_preprocess(self, image_path):
        """Load an image and convert it to the (256, 256, 1) binary tensor
        the model expects.

        Returns:
            (model_input, original (height, width), resized RGB image).
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"ไม่สามารถโหลดภาพจาก: {image_path}")

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Remember the input resolution so the output can be restored.
            original_size = image.shape[:2]

            # The generator was built for 256x256 inputs.
            image_resized = cv2.resize(image, (256, 256))
            gray = cv2.cvtColor(image_resized, cv2.COLOR_RGB2GRAY)

            # Binarize so lines are crisp (white background, black lines).
            # Equivalent to the previous THRESH_BINARY_INV + bitwise_not pair.
            _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

            # Normalize to [0, 1] and add the channel axis.
            binary = binary.astype(np.float32) / 255.0
            binary = np.expand_dims(binary, axis=-1)

            return binary, original_size, image_resized

        except Exception as e:
            print(f"❌ ข้อผิดพลาดในการประมวลผลภาพ: {e}")
            raise

    def colorize(self, image_path, output_path, save_comparison=False):
        """Colorize one line-art image and write the result to *output_path*.

        Returns the colored image (uint8, original resolution) or None on
        failure.
        """
        try:
            image, original_size, original_resized = self.load_and_preprocess(image_path)

            print("🎨 กำลังลงสี...")
            prediction = self.model.predict(image[np.newaxis, ...], verbose=0)[0]

            # Sigmoid output in [0, 1] -> uint8, then restore the resolution
            # (cv2.resize takes (width, height)).
            colored_image = (prediction * 255).astype(np.uint8)
            colored_image_original_size = cv2.resize(colored_image, (original_size[1], original_size[0]))

            Image.fromarray(colored_image_original_size).save(output_path)
            print(f"💾 บันทึกภาพที่: {output_path}")

            if save_comparison:
                comparison = self.create_comparison(original_resized, colored_image)
                # BUG FIX: str.replace('.png', ...) was a no-op for non-.png
                # outputs, making comparison_path == output_path and silently
                # overwriting the colored result; derive the path from the
                # real extension instead.
                root, ext = os.path.splitext(output_path)
                comparison_path = f"{root}_comparison{ext or '.png'}"
                comparison.save(comparison_path)
                print(f"📊 บันทึกภาพเปรียบเทียบที่: {comparison_path}")

            return colored_image_original_size

        except Exception as e:
            print(f"❌ ข้อผิดพลาดในการลงสี: {e}")
            return None

    def create_comparison(self, original, colored):
        """Return a side-by-side PIL image: input (left) vs colorized (right)."""
        # Promote grayscale inputs to RGB so they can sit next to the output.
        if len(original.shape) == 3 and original.shape[2] == 3:
            original_rgb = original
        else:
            original_rgb = cv2.cvtColor(original, cv2.COLOR_GRAY2RGB)

        comparison = np.hstack([original_rgb, colored])
        return Image.fromarray(comparison)
|
| 97 |
+
|
| 98 |
+
def main():
    """CLI entry point: parse arguments, validate input, run colorization."""
    parser = argparse.ArgumentParser(description='AI Manga Colorizer - Inference')
    parser.add_argument('--input', type=str, required=True, help='Path to input line art')
    parser.add_argument('--output', type=str, required=True, help='Path to save colored image')
    parser.add_argument('--model', type=str, default='output/manga_colorizer.h5', help='Path to trained model')
    parser.add_argument('--compare', action='store_true', help='Save comparison image')
    args = parser.parse_args()

    # Guard clause: bail out early when the input file is missing.
    if not os.path.exists(args.input):
        print(f"❌ ไม่พบไฟล์ input: {args.input}")
        return

    # Load the model and run inference; report any hard failure.
    try:
        colorizer = MangaColorizerInference(args.model)
        result = colorizer.colorize(args.input, args.output, save_comparison=args.compare)
    except Exception as e:
        print(f"❌ การทำงานล้มเหลว: {e}")
        return

    print("✅ เสร็จสิ้น!" if result is not None else "❌ การลงสีล้มเหลว")

if __name__ == "__main__":
    main()
|
model.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers, Model
|
| 3 |
+
from config import config
|
| 4 |
+
|
| 5 |
+
def downsample(filters, size, apply_batchnorm=True):
    """Conv -> (BatchNorm) -> LeakyReLU block that halves the spatial size.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_batchnorm: whether to normalize after the convolution.
    """
    init = tf.random_normal_initializer(0., 0.02)

    # Assemble the layer list first, then wrap it in a Sequential.
    stack = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=init, use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(layers.BatchNormalization())
    stack.append(layers.LeakyReLU())

    return tf.keras.Sequential(stack)
|
| 21 |
+
|
| 22 |
+
def upsample(filters, size, apply_dropout=False):
    """TransposedConv -> BatchNorm -> (Dropout) -> ReLU block that doubles
    the spatial size.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        apply_dropout: whether to add 50% dropout after normalization.
    """
    init = tf.random_normal_initializer(0., 0.02)

    # Assemble the layer list first, then wrap it in a Sequential.
    stack = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=init,
                               use_bias=False),
        layers.BatchNormalization(),
    ]
    if apply_dropout:
        stack.append(layers.Dropout(0.5))
    stack.append(layers.ReLU())

    return tf.keras.Sequential(stack)
|
| 42 |
+
|
| 43 |
+
def build_generator():
    """Build the U-Net generator: (256, 256, 1) line art -> colored output."""
    line_art_input = layers.Input(shape=[256, 256, 1])

    # Encoder: halve the spatial size at every step, down to 1x1.
    encoder_blocks = [
        downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),                        # (bs, 64, 64, 128)
        downsample(256, 4),                        # (bs, 32, 32, 256)
        downsample(512, 4),                        # (bs, 16, 16, 512)
        downsample(512, 4),                        # (bs, 8, 8, 512)
        downsample(512, 4),                        # (bs, 4, 4, 512)
        downsample(512, 4),                        # (bs, 2, 2, 512)
        downsample(512, 4),                        # (bs, 1, 1, 512)
    ]

    # Decoder: mirror of the encoder; dropout on the three deepest blocks.
    decoder_blocks = [
        upsample(512, 4, apply_dropout=True),      # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True),      # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),      # (bs, 8, 8, 1024)
        upsample(512, 4),                          # (bs, 16, 16, 1024)
        upsample(256, 4),                          # (bs, 32, 32, 512)
        upsample(128, 4),                          # (bs, 64, 64, 256)
        upsample(64, 4),                           # (bs, 128, 128, 128)
    ]

    weight_init = tf.random_normal_initializer(0., 0.02)
    # Final transposed conv restores full resolution; sigmoid keeps the
    # color channels in [0, 1].
    output_layer = layers.Conv2DTranspose(
        config.OUTPUT_CHANNELS, 4, strides=2,
        padding='same', kernel_initializer=weight_init,
        activation='sigmoid'
    )  # (bs, 256, 256, 3)

    # Run the encoder, remembering each intermediate activation.
    x = line_art_input
    skip_connections = []
    for encoder in encoder_blocks:
        x = encoder(x)
        skip_connections.append(x)

    # Run the decoder, concatenating the matching encoder output at each
    # level (deepest first, skipping the 1x1 bottleneck itself).
    for decoder, skip in zip(decoder_blocks, reversed(skip_connections[:-1])):
        x = decoder(x)
        x = layers.Concatenate()([x, skip])

    return Model(inputs=line_art_input, outputs=output_layer(x))
|
| 95 |
+
|
| 96 |
+
def build_discriminator():
    """Build the PatchGAN discriminator.

    Takes the conditioning line art and a (real or generated) color image and
    returns a 30x30 grid of raw patch logits (no final activation -- pair it
    with a from_logits loss).
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    line_art = layers.Input(shape=[256, 256, 1], name='input_image')
    color_image = layers.Input(shape=[256, 256, 3], name='target_image')

    # Condition on the line art by stacking it with the color channels.
    features = layers.concatenate([line_art, color_image])  # (bs, 256, 256, 4)

    features = downsample(64, 4, False)(features)  # (bs, 128, 128, 64)
    features = downsample(128, 4)(features)        # (bs, 64, 64, 128)
    features = downsample(256, 4)(features)        # (bs, 32, 32, 256)

    features = layers.ZeroPadding2D()(features)    # (bs, 34, 34, 256)
    features = layers.Conv2D(
        512, 4, strides=1,
        kernel_initializer=weight_init,
        use_bias=False
    )(features)                                    # (bs, 31, 31, 512)

    features = layers.BatchNormalization()(features)
    features = layers.LeakyReLU()(features)
    features = layers.ZeroPadding2D()(features)    # (bs, 33, 33, 512)

    patch_logits = layers.Conv2D(
        1, 4, strides=1,
        kernel_initializer=weight_init
    )(features)                                    # (bs, 30, 30, 1)

    return Model(inputs=[line_art, color_image], outputs=patch_logits)
|
| 126 |
+
|
| 127 |
+
class MangaColorizer(Model):
    """Conditional GAN for colorizing manga line art.

    Wraps a U-Net generator and a PatchGAN discriminator and implements a
    custom `train_step` so the pair can be trained with `Model.fit`.
    """
    def __init__(self):
        super().__init__()
        self.generator = build_generator()
        self.discriminator = build_discriminator()

    def compile(self, g_optimizer, d_optimizer, loss_fn):
        """Configure training.

        Args:
            g_optimizer: optimizer applied to the generator's variables.
            d_optimizer: optimizer applied to the discriminator's variables.
            loss_fn: adversarial loss; train.py passes
                BinaryCrossentropy(from_logits=True), matching the
                discriminator's raw (unactivated) outputs.
        """
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        """One optimization step for both networks on a single batch.

        `data` is a (line_art, colored_target) batch pair. Returns a dict of
        scalar losses that Keras surfaces in `History` and progress logging.
        """
        input_image, target_image = data

        # Two independent tapes: each differentiates only its own loss.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generator proposes a colorization for the line art.
            gen_output = self.generator(input_image, training=True)

            # Discriminator scores the real and generated pairs, both
            # conditioned on the same line-art input.
            disc_real_output = self.discriminator([input_image, target_image], training=True)
            disc_generated_output = self.discriminator([input_image, gen_output], training=True)

            # Both losses must be computed inside the tape context so
            # gradients can flow back through the forward passes above.
            gen_total_loss, gen_gan_loss, gen_l1_loss = self.generator_loss(
                disc_generated_output, gen_output, target_image
            )
            disc_loss = self.discriminator_loss(disc_real_output, disc_generated_output)

        # Gradient computation happens after the tapes close.
        generator_gradients = gen_tape.gradient(
            gen_total_loss, self.generator.trainable_variables
        )
        discriminator_gradients = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables
        )

        self.g_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_variables)
        )
        self.d_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_variables)
        )

        return {
            "gen_total_loss": gen_total_loss,
            "gen_gan_loss": gen_gan_loss,
            "gen_l1_loss": gen_l1_loss,
            "disc_loss": disc_loss
        }

    def generator_loss(self, disc_generated_output, gen_output, target):
        """Adversarial loss ("fool the discriminator") plus weighted L1.

        Returns (total_loss, gan_loss, l1_loss).
        """
        # The generator wants the discriminator to output "real" (ones).
        gan_loss = self.loss_fn(tf.ones_like(disc_generated_output), disc_generated_output)

        # L1 loss pulls the generated colors toward the ground truth.
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

        total_gen_loss = gan_loss + (100 * l1_loss)  # L1 weight = 100 (hard-coded)

        return total_gen_loss, gan_loss, l1_loss

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        """Standard GAN discriminator loss: real pairs -> 1, generated -> 0."""
        real_loss = self.loss_fn(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = self.loss_fn(tf.zeros_like(disc_generated_output), disc_generated_output)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss
|
train.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from model import MangaColorizer
|
| 3 |
+
from dataset import MangaDataset
|
| 4 |
+
from utils import save_comparison, prepare_directories, check_dataset_size
|
| 5 |
+
from config import config
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
class TrainingMonitor(tf.keras.callbacks.Callback):
    """Keras callback: saves a sample colorization and prints losses per epoch."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch, logs=None):
        # Visualize the generator's current output on one batch.
        for sample_input, _sample_target in self.dataset.take(1):
            preview = self.model.generator(sample_input, training=False)
            save_comparison(sample_input, preview, epoch, 0)

        # Report this epoch's losses.
        if logs:
            gen_loss = logs.get('gen_total_loss', 0)
            disc_loss = logs.get('disc_loss', 0)
            print(f"📊 Epoch {epoch+1}: Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
|
| 26 |
+
|
| 27 |
+
def main():
    """End-to-end training entry point.

    Prepares directories, validates and loads the dataset, builds and
    compiles the GAN, trains it, saves the generator to
    output/manga_colorizer.h5 and writes loss plots to
    output/training_loss.png.
    """
    # Create train_images/* and output/ if they don't exist yet.
    prepare_directories()

    # Sanity-check the dataset before starting an expensive training run.
    if not check_dataset_size(config.LINE_ART_DIR, config.COLORED_DIR):
        print("❌ กรุณาตรวจสอบ dataset ก่อนเริ่มฝึก")
        return

    # Load the paired (line art, colored) dataset.
    print("🔄 กำลังโหลดข้อมูล...")
    dataset = MangaDataset().load_data()

    # Verify the pipeline yields at least one batch before building the model.
    try:
        sample_batch = next(iter(dataset))
        print(f"✅ โหลดข้อมูลสำเร็จ: Batch size {sample_batch[0].shape}")
    except StopIteration:
        print("❌ ไม่มีข้อมูลใน dataset")
        return

    # Build the GAN (U-Net generator + PatchGAN discriminator).
    print("🔄 กำลังสร้างโมเดล...")
    colorizer = MangaColorizer()

    # Separate Adam optimizers for each network; beta_1=0.5 as is common
    # for GAN training. from_logits=True matches the discriminator's raw
    # (unactivated) outputs.
    generator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    colorizer.compile(
        g_optimizer=generator_optimizer,
        d_optimizer=discriminator_optimizer,
        loss_fn=loss_fn
    )

    print(f"✅ สร้างโมเดลสำเร็จ")
    print(f"📈 เริ่มฝึก {config.EPOCHS} epochs...")

    # Train; TrainingMonitor saves a preview image and prints losses each epoch.
    history = colorizer.fit(
        dataset,
        epochs=config.EPOCHS,
        callbacks=[TrainingMonitor(dataset)],
        verbose=1
    )

    # Save only the generator -- inference.py loads this file directly.
    os.makedirs('output', exist_ok=True)
    colorizer.generator.save('output/manga_colorizer.h5')
    print("✅ บันทึกโมเดลเรียบร้อย: output/manga_colorizer.h5")

    # Plot the loss curves; plotting failure must not lose the trained model,
    # hence the broad try/except around this best-effort step.
    try:
        plt.figure(figsize=(12, 4))

        # Left panel: generator losses (each key guarded in case a metric
        # name changes).
        plt.subplot(1, 2, 1)
        if 'gen_total_loss' in history.history:
            plt.plot(history.history['gen_total_loss'], label='Generator Total Loss')
        if 'gen_gan_loss' in history.history:
            plt.plot(history.history['gen_gan_loss'], label='Generator GAN Loss', linestyle='--')
        if 'gen_l1_loss' in history.history:
            plt.plot(history.history['gen_l1_loss'], label='Generator L1 Loss', linestyle=':')
        plt.title('Generator Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Right panel: discriminator loss.
        plt.subplot(1, 2, 2)
        if 'disc_loss' in history.history:
            plt.plot(history.history['disc_loss'], label='Discriminator Loss', color='red')
        plt.title('Discriminator Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.tight_layout()
        plt.savefig('output/training_loss.png', dpi=300, bbox_inches='tight')
        plt.close()
        print("✅ บันทึกกราฟ training loss: output/training_loss.png")

    except Exception as e:
        print(f"❌ ไม่สามารถบันทึกกราฟ loss: {e}")

if __name__ == "__main__":
    main()
|
utils.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def load_image(image_path, is_line_art=False):
    """Load an image file, resize to 256x256 and scale to [0, 1].

    Args:
        image_path: path to a PNG/JPEG file.
        is_line_art: when True, convert to grayscale and hard-threshold so
            lines are pure black on a pure white background.

    Returns:
        A float32 tensor: (256, 256, 3) for colored, (256, 256, 1) for line art.
    """
    image = tf.io.read_file(image_path)
    # BUG FIX: without expand_animations=False, decode_image can return a
    # 4-D tensor (animated GIF) with no static shape, which breaks
    # tf.image.resize in graph mode; also pin the static shape.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image.set_shape([None, None, 3])
    image = tf.image.resize(image, [256, 256])
    image = tf.cast(image, tf.float32) / 255.0

    if is_line_art:
        # Grayscale, then threshold: solid black lines, white background.
        image = tf.image.rgb_to_grayscale(image)
        image = tf.where(image < 0.5, 0.0, 1.0)

    return image
|
| 21 |
+
|
| 22 |
+
def extract_line_art_from_colored(colored_image):
    """Derive a line-art mask from a colored page (for building training data).

    Returns a uint8 image where detected line pixels are 255 (inverted
    adaptive threshold) with small specks removed by a morphological open.
    """
    grayscale = cv2.cvtColor(colored_image, cv2.COLOR_RGB2GRAY)

    # Adaptive thresholding keeps the lines crisp under uneven shading.
    edges = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )

    # A tiny opening kernel removes isolated noise pixels.
    open_kernel = np.ones((2, 2), np.uint8)
    return cv2.morphologyEx(edges, cv2.MORPH_OPEN, open_kernel)
|
| 38 |
+
|
| 39 |
+
def save_comparison(image, prediction, epoch, step):
    """Save an input-vs-prediction side-by-side PNG under output/.

    *image* and *prediction* are batched tensors in [0, 1]; only the first
    sample of each is used. Failures are logged, never raised, so a broken
    preview cannot abort training.
    """
    try:
        # Take the first sample of the batch as plain numpy arrays.
        line_art = image[0].numpy()
        colored = prediction[0].numpy()

        # Promote the grayscale input to 3 channels so it can sit next to
        # the RGB prediction.
        if line_art.shape[-1] == 1:
            line_art = np.repeat(line_art, 3, axis=-1)

        # Stitch input (left) and prediction (right), then convert to uint8.
        side_by_side = np.concatenate([line_art, colored], axis=1)
        side_by_side = (np.clip(side_by_side, 0, 1) * 255).astype(np.uint8)

        os.makedirs('output', exist_ok=True)
        Image.fromarray(side_by_side).save(f'output/epoch_{epoch:03d}_step_{step:03d}.png')
        print(f"💾 บันทึกภาพเปรียบเทียบ: output/epoch_{epoch:03d}_step_{step:03d}.png")

    except Exception as e:
        print(f"❌ ข้อผิดพลาดในการบันทึกภาพเปรียบเทียบ: {e}")
|
| 67 |
+
|
| 68 |
+
def prepare_directories():
    """Create every directory the training pipeline writes to (idempotent)."""
    for path in ('train_images/line_arts', 'train_images/colored', 'output'):
        os.makedirs(path, exist_ok=True)
    print("✅ สร้าง directory เรียบร้อย")
|
| 74 |
+
|
| 75 |
+
# ฟังก์ชันเพิ่มเติมสำหรับการตรวจสอบข้อมูล
|
| 76 |
+
def check_dataset_size(line_art_dir, colored_dir):
    """Verify the paired dataset looks usable before training.

    Counts .png/.jpg files (the same extensions the dataset loader accepts)
    in each directory and returns True only when both directories exist and
    hold the same number of images.

    Args:
        line_art_dir: directory of line-art images.
        colored_dir: directory of the matching colored images.

    Returns:
        bool -- True when the dataset passes the checks.
    """
    def _list_images(directory):
        # ROBUSTNESS FIX: os.listdir raised FileNotFoundError on a missing
        # directory; report failure instead of crashing the caller.
        if not os.path.isdir(directory):
            return None
        return [f for f in os.listdir(directory) if f.endswith(('.png', '.jpg'))]

    line_art_files = _list_images(line_art_dir)
    colored_files = _list_images(colored_dir)

    if line_art_files is None or colored_files is None:
        print("⚠️ dataset directory not found")
        return False

    print(f"📁 ภาพเส้น: {len(line_art_files)} ไฟล์")
    print(f"🎨 ภาพสี: {len(colored_files)} ไฟล์")

    if len(line_art_files) != len(colored_files):
        print("⚠️ จำนวนภาพเส้นและภาพสีไม่เท่ากัน!")
        return False
    return True
|