File size: 5,867 Bytes
7a3f743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# =============================================================================
# Project: WHITE-BOX-CARTOONIZATION
# Authors: Amey Thakur & Mega Satish
# Date: 2021-08-28
# Repository: https://github.com/Amey-Thakur/WHITE-BOX-CARTOONIZATION
# Profiles: https://github.com/Amey-Thakur | https://github.com/msatmod
# =============================================================================

"""
backend.py
=============================================================================
This module handles the interaction with the TensorFlow deep-learning model.
It wraps the complex machine learning code into a simple class `Cartoonizer`.

Key Responsibilities:
1.  Load the pre-trained neural network weights.
2.  Pre-process input images (resize, crop).
3.  Run the actual cartoonization inference.
4.  Return the processed image.
=============================================================================
"""
import os
import cv2
import numpy as np
import tensorflow as tf
import sys

# Add the 'src' directory to Python's search path so we can import 'network' and 'guided_filter'
# These are helper files from the original research paper implementation
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src import network
from src import guided_filter

try:
    # "tf_slim" is the standalone package that replaces the old
    # tensorflow.contrib.slim API used by the original research code.
    # We bind it as 'slim' because the legacy network code expects that name.
    import tf_slim as slim
except ImportError:
    # Fall back to the TF 1.x location of slim (tensorflow.contrib.slim).
    # The original code printed this message but never actually attempted
    # the fallback import — do it for real here.
    print("Warning: tf_slim not found. Trying tensorflow.contrib.slim...")
    try:
        from tensorflow.contrib import slim
    except ImportError:
        # Neither variant is available; building the network will fail later
        # with a clear NameError rather than being silently ignored here.
        print("Warning: tensorflow.contrib.slim not available either; "
              "the generator network may fail to build.")

class Cartoonizer:
    """Wraps the White-Box Cartoonization TensorFlow model behind a simple API.

    The TF1-style inference graph is built once at construction time and the
    session is kept open, so repeated calls to ``predict`` only pay for
    inference, not graph building or weight loading.
    """

    def __init__(self, model_path):
        """
        Initialize the Cartoonizer and load the model immediately.

        :param model_path: Path to the folder containing the saved model
                           checkpoint files.
        :raises ValueError: if no checkpoint can be found in ``model_path``.
        """
        self.model_path = model_path
        self.sess = None          # tf.compat.v1.Session, created by _load_model
        self.input_photo = None   # input placeholder tensor [1, H, W, 3]
        self.final_out = None     # output tensor of the cartoonization graph

        # Build the graph and restore the weights once, up front.
        self._load_model()

    def _load_model(self):
        """
        Build the TensorFlow computation graph and restore the saved weights.

        This setup happens only once (at construction) to save time.

        :raises ValueError: if ``self.model_path`` contains no checkpoint.
        """
        # Disable "Eager Execution": TF 2.x runs ops immediately (eager), but
        # this model was written for TF 1.x, which builds a graph first.
        try:
            tf.compat.v1.disable_eager_execution()
        except Exception:
            # TF 1.x has no eager execution to disable — safe to ignore.
            pass

        # 1. Input placeholder (where the image goes in).
        #    Shape: [batch, height, width, channels]; H/W are dynamic.
        self.input_photo = tf.compat.v1.placeholder(tf.float32, [1, None, None, 3])

        # 2. Generator network (the "artist"): produces the raw cartoon image.
        network_out = network.unet_generator(self.input_photo)

        # 3. Guided filter (the "polisher"): refines edges using the input
        #    photo as guidance so the result looks more like a cartoon.
        self.final_out = guided_filter.guided_filter(
            self.input_photo, network_out, r=1, eps=5e-3)

        # 4. Restore only the generator's variables from the checkpoint.
        all_vars = tf.compat.v1.trainable_variables()
        gene_vars = [var for var in all_vars if 'generator' in var.name]
        saver = tf.compat.v1.train.Saver(var_list=gene_vars)

        # 5. Start the session; allow_growth avoids grabbing all GPU memory
        #    up front when a GPU is available.
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.compat.v1.Session(config=config)

        # 6. Initialize all variables, then overwrite the generator variables
        #    with the pre-trained weights from the checkpoint.
        self.sess.run(tf.compat.v1.global_variables_initializer())
        checkpoint = tf.compat.v1.train.latest_checkpoint(self.model_path)
        if checkpoint is None:
            # Fail with a clear message instead of letting saver.restore
            # crash on a None path (which raises an opaque ValueError).
            raise ValueError(
                f"No model checkpoint found in: {self.model_path}")
        saver.restore(self.sess, checkpoint)
        print("Backend: Model loaded successfully!")

    def resize_crop(self, image):
        """
        Resize and crop an image to be compatible with the model.

        The shorter side is capped at 720 px (to bound memory use) and both
        dimensions are cropped down to multiples of 8, which the U-Net
        architecture requires.

        :param image: HxWx3 uint8 array (BGR, as decoded by OpenCV).
        :return: Resized and cropped image array.
        """
        h, w = image.shape[:2]
        # Limit the size to avoid running out of memory on large images:
        # scale so the SHORTER side becomes 720, preserving aspect ratio.
        if min(h, w) > 720:
            if h > w:
                h, w = int(720 * h / w), 720
            else:
                h, w = 720, int(720 * w / h)
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)

        # Crop so both dimensions are divisible by 8.
        h, w = (h // 8) * 8, (w // 8) * 8
        image = image[:h, :w, :]
        return image

    def predict(self, image_bytes):
        """
        Cartoonize a raw image file.

        :param image_bytes: Raw bytes of the uploaded image file.
        :return: Raw bytes of the cartoonized JPEG image, or ``b''`` if the
                 input could not be decoded or the output could not be encoded.
        """
        # 1. Decode bytes -> image matrix (height, width, BGR channels).
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if image is None:
            print("ERROR: Failed to decode image!")
            return b''

        # 2. Pre-process (resize/crop to model-compatible dimensions).
        image = self.resize_crop(image)

        # 3. Normalize pixel values from [0, 255] to [-1, 1] for the model.
        batch_image = image.astype(np.float32) / 127.5 - 1
        # Add batch dimension: [H, W, 3] -> [1, H, W, 3].
        batch_image = np.expand_dims(batch_image, axis=0)

        # 4. Run inference: feed the image into the placeholder defined
        #    in _load_model and fetch the filtered output.
        output = self.sess.run(self.final_out,
                               feed_dict={self.input_photo: batch_image})

        # 5. Post-process: map [-1, 1] back to [0, 255] uint8.
        output = (np.squeeze(output) + 1) * 127.5
        output = np.clip(output, 0, 255).astype(np.uint8)
        print(f"DEBUG: Final output shape: {output.shape}")

        # 6. Encode image matrix -> JPEG bytes. Check the success flag:
        #    the original ignored it and could return garbage on failure.
        success, buffer = cv2.imencode('.jpg', output)
        if not success:
            print("ERROR: Failed to encode output image!")
            return b''
        return buffer.tobytes()