Commit ·
23811f4
1
Parent(s): e4ea306
Delete train.py
Browse files
train.py
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import random
|
| 3 |
-
import numpy as np
|
| 4 |
-
from glob import glob
|
| 5 |
-
from PIL import Image, ImageOps
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
-
import tensorflow as tf
|
| 8 |
-
from tensorflow import keras
|
| 9 |
-
from tensorflow.keras import layers
|
| 10 |
-
from model import get_model
|
| 11 |
-
|
| 12 |
-
# functions to create the dataset
|
| 13 |
-
random.seed(1)
|
| 14 |
-
IMAGE_SIZE = 128
|
| 15 |
-
BATCH_SIZE = 4
|
| 16 |
-
MAX_TRAIN_IMAGES = 300
|
| 17 |
-
|
| 18 |
-
def autocontrast(tensor, cutoff=0):
|
| 19 |
-
tensor = tf.cast(tensor, dtype=tf.float32)
|
| 20 |
-
min_val = tf.reduce_min(tensor)
|
| 21 |
-
max_val = tf.reduce_max(tensor)
|
| 22 |
-
range_val = max_val - min_val
|
| 23 |
-
adjusted_tensor = tf.clip_by_value(tf.cast(tf.round((tensor - min_val - cutoff) * (255 / (range_val - 2 * cutoff))), tf.uint8), 0, 255)
|
| 24 |
-
return adjusted_tensor
|
| 25 |
-
|
| 26 |
-
def read_image(image_path):
|
| 27 |
-
image = tf.io.read_file(image_path)
|
| 28 |
-
image = tf.image.decode_png(image, channels=3)
|
| 29 |
-
image = autocontrast(image)
|
| 30 |
-
image.set_shape([None, None, 3])
|
| 31 |
-
image = tf.cast(image, dtype=tf.float32) / 255
|
| 32 |
-
return image
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def random_crop(low_image, enhanced_image):
|
| 36 |
-
low_image_shape = tf.shape(low_image)[:2]
|
| 37 |
-
low_w = tf.random.uniform(
|
| 38 |
-
shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
|
| 39 |
-
)
|
| 40 |
-
low_h = tf.random.uniform(
|
| 41 |
-
shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
|
| 42 |
-
)
|
| 43 |
-
enhanced_w = low_w
|
| 44 |
-
enhanced_h = low_h
|
| 45 |
-
low_image_cropped = low_image[
|
| 46 |
-
low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
|
| 47 |
-
]
|
| 48 |
-
enhanced_image_cropped = enhanced_image[
|
| 49 |
-
enhanced_h : enhanced_h + IMAGE_SIZE, enhanced_w : enhanced_w + IMAGE_SIZE
|
| 50 |
-
]
|
| 51 |
-
return low_image_cropped, enhanced_image_cropped
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def load_data(low_light_image_path, enhanced_image_path):
|
| 55 |
-
low_light_image = read_image(low_light_image_path)
|
| 56 |
-
enhanced_image = read_image(enhanced_image_path)
|
| 57 |
-
low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
|
| 58 |
-
return low_light_image, enhanced_image
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def get_dataset(low_light_images, enhanced_images):
|
| 62 |
-
dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
|
| 63 |
-
dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
|
| 64 |
-
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
|
| 65 |
-
return dataset
|
| 66 |
-
|
| 67 |
-
# Loss functions
|
| 68 |
-
|
| 69 |
-
class CustomLoss:
|
| 70 |
-
def __init__(self, perceptual_loss_model):
|
| 71 |
-
self.perceptual_loss_model = perceptual_loss_model
|
| 72 |
-
def perceptual_loss(self, y_true, y_pred):
|
| 73 |
-
y_true_features = self.perceptual_loss_model(y_true)
|
| 74 |
-
y_pred_features = self.perceptual_loss_model(y_pred)
|
| 75 |
-
loss = tf.reduce_mean(tf.square(y_true_features[0] - y_pred_features[0])) + tf.reduce_mean(tf.square(y_true_features[1] - y_pred_features[1]))
|
| 76 |
-
return loss
|
| 77 |
-
def charbonnier_loss(self, y_true, y_pred):
|
| 78 |
-
return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
|
| 79 |
-
def __call__(self, y_true, y_pred):
|
| 80 |
-
return 0.5*self.perceptual_loss(y_true, y_pred) + 0.4*self.charbonnier_loss(y_true, y_pred)
|
| 81 |
-
|
| 82 |
-
def peak_signal_noise_ratio(y_true, y_pred):
|
| 83 |
-
return tf.image.psnr(y_pred, y_true, max_val=255.0)
|
| 84 |
-
|
| 85 |
-
def main():
|
| 86 |
-
train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
|
| 87 |
-
train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
|
| 88 |
-
|
| 89 |
-
val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
|
| 90 |
-
val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
|
| 91 |
-
|
| 92 |
-
train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
|
| 93 |
-
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
|
| 94 |
-
|
| 95 |
-
#Model for calculating perceptual loss
|
| 96 |
-
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
|
| 97 |
-
for layer in vgg.layers:
|
| 98 |
-
layer.trainable = False #Freeze all the layers, since this model is for evaluation only
|
| 99 |
-
outputs = [vgg.get_layer('block3_conv3').output, vgg.get_layer('block4_conv3').output]
|
| 100 |
-
perceptual_loss_model = tf.keras.models.Model(inputs=vgg.input, outputs=outputs)
|
| 101 |
-
|
| 102 |
-
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
|
| 103 |
-
loss = CustomLoss(perceptual_loss_model)
|
| 104 |
-
model = get_model()
|
| 105 |
-
|
| 106 |
-
model.compile(
|
| 107 |
-
optimizer=optimizer, loss=loss, metrics=[peak_signal_noise_ratio]
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
|
| 111 |
-
model.save_weights("model.h5")
|
| 112 |
-
|
| 113 |
-
if __name__ == "__main__":
|
| 114 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|