# =============================================================================
# Project: WHITE-BOX-CARTOONIZATION
# Authors: Amey Thakur & Mega Satish
# Date: 2021-08-28
# Repository: https://github.com/Amey-Thakur/WHITE-BOX-CARTOONIZATION
# Profiles: https://github.com/Amey-Thakur | https://github.com/msatmod
# =============================================================================
"""
backend.py
=============================================================================
This module handles the interaction with the TensorFlow Artificial Intelligence model.
It wraps the complex machine learning code into a simple class `Cartoonizer`.
Key Responsibilities:
1. Load the pre-trained neural network weights.
2. Pre-process input images (resize, crop).
3. Run the actual cartoonization inference.
4. Return the processed image.
=============================================================================
"""
import os
import sys

import cv2
import numpy as np
import tensorflow as tf

# Add the 'src' directory to Python's search path so we can import 'network'
# and 'guided_filter' — helper modules from the original research paper
# implementation.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from src import network
from src import guided_filter

try:
    # "tf_slim" is the standalone successor of tensorflow.contrib.slim.
    # We bind it to the name 'slim' because the old code expects that name.
    import tf_slim as slim
except ImportError:
    # Fix: the original branch only *announced* the contrib fallback without
    # performing it, leaving 'slim' unbound. Actually attempt the TF 1.x
    # location as well, and warn clearly if neither is available.
    print("Warning: tf_slim not found. Trying tensorflow.contrib.slim...")
    try:
        import tensorflow.contrib.slim as slim
    except ImportError:
        print("Warning: no slim implementation available; "
              "model construction may fail.")
class Cartoonizer:
    """Wrapper around the White-Box Cartoonization TensorFlow 1.x model.

    Builds the inference graph and restores the pre-trained weights once at
    construction time, then exposes a single :meth:`predict` method that maps
    raw encoded image bytes to cartoonized JPEG bytes.
    """

    def __init__(self, model_path):
        """
        Initialize the Cartoonizer.

        :param model_path: Path to the folder containing the saved model
                           weights (TensorFlow checkpoint files).
        """
        self.model_path = model_path
        self.sess = None          # TF session, created in _load_model()
        self.input_photo = None   # input placeholder [1, H, W, 3]
        self.final_out = None     # output tensor of the cartoonization graph
        # Load the model immediately when this object is created.
        self._load_model()

    def _load_model(self):
        """
        Build the TensorFlow computation graph and restore the saved weights.

        This setup happens only once (at construction) so the graph is not
        rebuilt for every prediction.

        :raises FileNotFoundError: if no checkpoint exists in ``model_path``.
        """
        # Disable "Eager Execution": TensorFlow 2.x runs code immediately
        # (eagerly), but this older model was built for TF 1.x graph mode.
        try:
            tf.compat.v1.disable_eager_execution()
        except Exception:
            pass  # already disabled, or running on TF 1.x

        # 1. Input placeholder: [batch, height, width, channels].
        self.input_photo = tf.compat.v1.placeholder(tf.float32, [1, None, None, 3])

        # 2. Generator network (the "artist") producing the raw cartoon.
        network_out = network.unet_generator(self.input_photo)

        # 3. Guided filter (the "polisher") refining edges toward a
        #    cartoon-like look.
        self.final_out = guided_filter.guided_filter(
            self.input_photo, network_out, r=1, eps=5e-3)

        # 4. Saver restricted to generator variables — the checkpoint
        #    contains only the generator's weights.
        all_vars = tf.compat.v1.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.compat.v1.train.Saver(var_list=gene_vars)

        # 5. Start the session; allow_growth avoids grabbing all GPU memory
        #    up front if a GPU is available.
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)

        # 6. Initialize variables, then restore the pre-trained weights.
        self.sess.run(tf.compat.v1.global_variables_initializer())
        # Fix: use the compat.v1 API consistently and fail loudly with a
        # clear message if no checkpoint is found (restore(None) would
        # otherwise raise a cryptic error).
        checkpoint = tf.compat.v1.train.latest_checkpoint(self.model_path)
        if checkpoint is None:
            raise FileNotFoundError(
                f"No TensorFlow checkpoint found in '{self.model_path}'")
        saver.restore(self.sess, checkpoint)
        print("Backend: Model loaded successfully!")

    def resize_crop(self, image):
        """
        Resize and crop ``image`` so it is compatible with the model.

        The network downsamples by powers of two, so both dimensions must be
        multiples of 8. Large images are first scaled down so the shorter
        side is at most 720 px to limit memory usage.

        :param image: HxWxC uint8 array (BGR as produced by cv2.imdecode).
        :return: Image whose height and width are multiples of 8.
        """
        # Fix: read only the first two dims so grayscale input does not
        # break tuple unpacking.
        h, w = np.shape(image)[:2]
        # Limit the size to avoid running out of memory on large images.
        if min(h, w) > 720:
            if h > w:
                h, w = int(720 * h / w), 720
            else:
                h, w = 720, int(720 * w / h)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        # Crop so dimensions are divisible by 8.
        h, w = (h // 8) * 8, (w // 8) * 8
        if h == 0 or w == 0:
            # Fix: inputs smaller than 8 px in a dimension used to produce
            # an empty crop; upscale to the minimum valid size instead.
            image = cv2.resize(image, (max(w, 8), max(h, 8)),
                               interpolation=cv2.INTER_AREA)
            return image
        return image[:h, :w]

    def predict(self, image_bytes):
        """
        Cartoonize an image — the main public method.

        :param image_bytes: Raw bytes of an encoded image file (JPEG/PNG/...).
        :return: Raw bytes of the cartoonized JPEG image, or ``b''`` if the
                 input could not be decoded or the result could not be encoded.
        """
        # 1. Decode bytes -> BGR image matrix (Height, Width, Colors).
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            print("ERROR: Failed to decode image!")
            return b''
        # 2. Pre-process (resize/crop to multiples of 8).
        image = self.resize_crop(image)
        # 3. Normalize pixel values from [0, 255] to [-1, 1] for the model.
        batch_image = image.astype(np.float32) / 127.5 - 1
        # Add the batch dimension: [H, W, 3] -> [1, H, W, 3].
        batch_image = np.expand_dims(batch_image, axis=0)
        # 4. Run inference: feed_dict routes the batch into the placeholder
        #    defined in _load_model().
        output = self.sess.run(self.final_out,
                               feed_dict={self.input_photo: batch_image})
        # 5. Post-process: [-1, 1] -> [0, 255], clamp, convert to uint8.
        output = (np.squeeze(output) + 1) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        print(f"DEBUG: Final output shape: {output.shape}")
        # 6. Encode image matrix -> JPEG bytes.
        # Fix: check the success flag instead of silently ignoring an
        # encoding failure.
        success, buffer = cv2.imencode('.jpg', output)
        if not success:
            print("ERROR: Failed to encode output image!")
            return b''
        return buffer.tobytes()