# =============================================================================
# Project: WHITE-BOX-CARTOONIZATION
# Authors: Amey Thakur & Mega Satish
# Date: 2021-08-28
# Repository: https://github.com/Amey-Thakur/WHITE-BOX-CARTOONIZATION
# Profiles: https://github.com/Amey-Thakur | https://github.com/msatmod
# =============================================================================
"""
backend.py
=============================================================================
This module handles the interaction with the TensorFlow Artificial Intelligence
model. It wraps the complex machine learning code into a simple class
`Cartoonizer`.

Key Responsibilities:
1. Load the pre-trained neural network weights.
2. Pre-process input images (resize, crop).
3. Run the actual cartoonization inference.
4. Return the processed image.
=============================================================================
"""

import os
import cv2
import numpy as np
import tensorflow as tf
import sys

# Add the 'src' directory to Python's search path so we can import 'network'
# and 'guided_filter'. These are helper files from the original research paper
# implementation.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src import network
from src import guided_filter

try:
    # "tf_slim" is a library used to define complex neural networks.
    # We rename it to 'slim' because the old code expects that name.
    import tf_slim as slim
except ImportError:
    print("Warning: tf_slim not found. Trying tensorflow.contrib.slim...")
    try:
        # Actually attempt the fallback the warning above announces.
        # tensorflow.contrib only exists in TensorFlow 1.x.
        import tensorflow.contrib.slim as slim
    except ImportError:
        # Neither is available; this module does not use `slim` directly.
        slim = None


class Cartoonizer:
    """Wrap the White-Box Cartoonization generator for single-image inference.

    The model is built and its weights are restored once, in the constructor.
    Each instance owns a private TensorFlow graph and session, so several
    instances can coexist in one process without their variables colliding.
    """

    # Inputs whose shorter edge exceeds this many pixels are downscaled so the
    # shorter edge equals it (limits memory use on large images).
    MAX_SHORT_SIDE = 720
    # Height and width are cropped down to a multiple of this value; the model
    # works best with dimensions that are multiples of 8.
    DIM_MULTIPLE = 8

    def __init__(self, model_path):
        """
        Initialize the Cartoonizer and load the model immediately.

        :param model_path: Path to the folder containing the saved model
            weights (a TensorFlow checkpoint directory).
        :raises ValueError: If no checkpoint can be found in ``model_path``.
        """
        self.model_path = model_path
        self.graph = None
        self.sess = None
        self.input_photo = None
        self.final_out = None

        # Load the model immediately when this object is created
        self._load_model()

    def _load_model(self):
        """
        Build the inference graph and restore the saved generator weights.

        This setup happens only once per instance to save time.

        :raises ValueError: If ``self.model_path`` contains no checkpoint.
        """
        # Disable "Eager Execution".
        # TensorFlow 2.x runs code immediately (Eager), but this older model
        # was built for TensorFlow 1.x which builds a "Graph" first.
        try:
            tf.compat.v1.disable_eager_execution()
        except Exception as err:
            # Best effort: keep going, but don't hide why it failed.
            print(f"Warning: could not disable eager execution: {err}")

        # Resolve the checkpoint first so a bad path fails fast with a clear
        # message. latest_checkpoint returns None when nothing is found, and
        # passing None to Saver.restore gives an unhelpful error.
        checkpoint = tf.train.latest_checkpoint(self.model_path)
        if checkpoint is None:
            raise ValueError(
                f"No TensorFlow checkpoint found in model_path={self.model_path!r}"
            )

        # Build everything in a private graph. With the process-wide default
        # graph, a second instance would add a second set of 'generator'
        # variables and the Saver's var_list would mix them.
        self.graph = tf.Graph()
        with self.graph.as_default():
            # 1. Define the Input Placeholder (Where the image goes in)
            # Shape: [Batch_Size, Height, Width, Channels]
            self.input_photo = tf.compat.v1.placeholder(
                tf.float32, [1, None, None, 3]
            )

            # 2. Build the Generator Network (The "Artist")
            network_out = network.unet_generator(self.input_photo)

            # 3. Apply Guided Filter (The "Polisher") to refine the edges
            self.final_out = guided_filter.guided_filter(
                self.input_photo, network_out, r=1, eps=5e-3
            )

            # 4. Create a Saver restricted to the generator's trainable
            #    variables (the pre-trained weights).
            all_vars = tf.compat.v1.trainable_variables()
            gene_vars = [var for var in all_vars if 'generator' in var.name]
            saver = tf.compat.v1.train.Saver(var_list=gene_vars)

            init_op = tf.compat.v1.global_variables_initializer()

        # 5. Start a TensorFlow Session bound to this instance's graph
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True  # Grow GPU memory on demand
        self.sess = tf.compat.v1.Session(graph=self.graph, config=config)

        # 6. Initialize, then overwrite the generator variables from disk
        self.sess.run(init_op)
        saver.restore(self.sess, checkpoint)
        print("Backend: Model loaded successfully!")

    def close(self):
        """Release the TensorFlow session. Safe to call more than once.

        After closing, ``predict`` can no longer be used on this instance.
        """
        if self.sess is not None:
            self.sess.close()
            self.sess = None

    def resize_crop(self, image):
        """
        Resize and crop an image so the model can process it.

        If the shorter edge exceeds ``MAX_SHORT_SIDE``, the image is scaled
        down (aspect ratio kept, INTER_AREA) so that edge equals
        ``MAX_SHORT_SIDE``. Height and width are then cropped from the
        bottom/right down to a multiple of ``DIM_MULTIPLE``.

        :param image: Image array of shape (H, W) or (H, W, C).
        :return: The resized/cropped array. It is empty if H or W is smaller
            than ``DIM_MULTIPLE``.
        """
        limit = self.MAX_SHORT_SIDE
        step = self.DIM_MULTIPLE
        h, w = np.shape(image)[:2]

        # Limit the size to avoid running out of memory on large images
        if min(h, w) > limit:
            if h > w:
                h, w = int(limit * h / w), limit
            else:
                h, w = limit, int(limit * w / h)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

        # Ensure dimensions are divisible by DIM_MULTIPLE
        h, w = (h // step) * step, (w // step) * step
        return image[:h, :w]

    def predict(self, image_bytes):
        """
        Cartoonize an encoded image. This is the main public method.

        :param image_bytes: Raw bytes of the uploaded image file.
        :return: Raw bytes of the cartoonized JPEG image, or ``b''`` if the
            input is empty, cannot be decoded, is too small after cropping,
            or the result cannot be encoded.
        """
        # cv2.imdecode errors out on an empty buffer instead of returning None,
        # so reject empty input up front with the same failure contract.
        if not image_bytes:
            print("ERROR: Failed to decode image!")
            return b''

        # 1. Decode bytes -> Image Matrix (Height, Width, Colors)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            print("ERROR: Failed to decode image!")
            return b''

        # 2. Pre-process (Resize/Crop)
        image = self.resize_crop(image)
        if image.size == 0:
            # An edge shorter than DIM_MULTIPLE px was cropped away entirely.
            print("ERROR: Image is too small to cartoonize!")
            return b''

        # 3. Normalize pixel values from [0, 255] to [-1, 1] for the model
        batch_image = image.astype(np.float32) / 127.5 - 1
        # Add batch dimension: [H, W, 3] -> [1, H, W, 3]
        batch_image = np.expand_dims(batch_image, axis=0)

        # 4. Run the model, feeding the image into the input placeholder
        output = self.sess.run(
            self.final_out, feed_dict={self.input_photo: batch_image}
        )

        # 5. Post-process: drop the batch axis, then map [-1, 1] -> [0, 255]
        output = (np.squeeze(output, axis=0) + 1) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        print(f"DEBUG: Final output shape: {output.shape}")

        # 6. Encode Image Matrix -> JPEG bytes (check the success flag)
        ok, buffer = cv2.imencode('.jpg', output)
        if not ok:
            print("ERROR: Failed to encode image!")
            return b''
        return buffer.tobytes()