import face_recognition
import numpy as np
import os
import torch
from torchvision import transforms
from torchvision.io import write_video
import tempfile
import subprocess
from fractions import Fraction
from ffmpy import FFmpeg, FFprobe
from PIL import Image

# Blending masks: mask512 weights each 512x512 sliding-window output so that
# overlapping windows blend smoothly; mask1024 masks the merged 1024x1024 result.
mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255

def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
    """
    Apply the aging operation to input_tensor using a sliding-window method,
    blending overlapping window outputs with the masks. Runs on the model's
    device (GPU, if available) and returns the result on the CPU.
    """

    input_tensor = input_tensor.to(next(your_model.parameters()).device)
    mask = mask.to(next(your_model.parameters()).device)
    small_mask = small_mask.to(next(your_model.parameters()).device)

    n, c, h, w = input_tensor.size()
    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)

    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)

    # With add = 1 the loop includes the window flush with the image edge when
    # window_size is a multiple of stride; otherwise add = 2 extends the range
    add = 2 if window_size % stride != 0 else 1
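    # For example, with h = w = 1024, window_size = 512 and stride = 256:
    # add = 1 and the window origins are 0, 256 and 512, so the windows tile
    # the full crop with 50% overlap.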

    for y in range(0, h - window_size + add, stride):
        for x in range(0, w - window_size + add, stride):
            window = input_tensor[:, :, y:y + window_size, x:x + window_size]

            # Forward pass; no gradients are needed at inference time
            with torch.no_grad():
                output = your_model(window)

            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask

    # Average the overlapping regions; the clamp avoids division by zero in
    # pixels that no window covered
    count_tensor = torch.clamp(count_tensor, min=1.0)
    output_tensor /= count_tensor

    # Apply mask
    output_tensor *= mask

    return output_tensor.cpu()
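
# Note: process_image below builds the (n, 5, 1024, 1024) input expected by
# sliding_window_tensor; the returned tensor is a masked aging residual that
# is added back onto the original image.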


def process_image(your_model, image, video, source_age, target_age=0,
                  window_size=512, stride=256, steps=18):
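    """
    Age a single face image from source_age to target_age (both divided by
    100 before being fed to the model).

    If video is True, render `steps` progressively aged frames and return the
    path of a temporary .mp4 file; otherwise return the aged PIL image.
    """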
    input_size = (1024, 1024)
    # Robustly coerce the input into an RGB uint8 array for face_recognition
    if isinstance(image, Image.Image):
        image = np.array(image.convert('RGB'))
    elif isinstance(image, np.ndarray):
        if image.ndim == 2:  # grayscale -> replicate into 3 channels
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[2] == 4:  # RGBA -> drop the alpha channel
            image = image[..., :3]
        if image.dtype in (np.float32, np.float64):
            # Assume float images are in [0, 1] unless values exceed 1
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        elif image.dtype != np.uint8:
            image = image.astype(np.uint8)
    else:
        # Fall back to coercing unknown array-like inputs through PIL
        image = np.array(Image.fromarray(np.asarray(image)).convert('RGB'))
    # Ensure shape is (H, W, 3) and the memory is contiguous
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"Image must have shape (H, W, 3), got {image.shape}")
    image = np.ascontiguousarray(image, dtype=np.uint8)
    print(f"[DEBUG] image type: {type(image)}, shape: {image.shape}, dtype: {image.dtype}, contiguous: {image.flags['C_CONTIGUOUS']}")
    if video:  # the h264 codec requires both frame dimensions to be divisible by 2
        height, width, depth = image.shape
        new_height = height if height % 2 == 0 else height - 1
        new_width = width if width % 2 == 0 else width - 1
        image = image[:new_height, :new_width, :]  # crop one row/column if odd

    # Diagnostic: try face_recognition directly; if it fails, round-trip the
    # image through a PNG on disk and retry
    try:
        fl = face_recognition.face_locations(image)[0]
    except Exception as e:
        print(f"[DEBUG] face_locations failed: {e}. Saving image for test...")
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            temp_path = tmp.name
        Image.fromarray(image).save(temp_path)
        print(f"[DEBUG] Saved image to {temp_path}. Trying face_recognition.load_image_file...")
        loaded_img = face_recognition.load_image_file(temp_path)
        print(f"[DEBUG] loaded_img type: {type(loaded_img)}, shape: {loaded_img.shape}, dtype: {loaded_img.dtype}")
        fl = face_recognition.face_locations(loaded_img)[0]

    # calculate crop margins; fl is (top, right, bottom, left)
    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
    margin_x = int((fl[1] - fl[3]) // (2 / .85))
    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # keep the crop square: total y-margin equals 2 * margin_x

    l_y = max([fl[0] - margin_y_t, 0])
    r_y = min([fl[2] + margin_y_b, image.shape[0]])
    l_x = max([fl[3] - margin_x, 0])
    r_x = min([fl[1] + margin_x, image.shape[1]])

    # crop image
    cropped_image = image[l_y:r_y, l_x:r_x, :]

    # Resizing
    orig_size = cropped_image.shape[:2]

    cropped_image = transforms.ToTensor()(cropped_image)

    cropped_image_resized = transforms.Resize(input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(cropped_image)

    # Build the 5-channel model input: RGB plus constant source-age and
    # target-age maps, each normalized to [0, 1]
    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)

    image = transforms.ToTensor()(image)

    if video:
        # Age in steps: the target-age channel is incremented by `interval`
        # each iteration, producing progressively older frames
        interval = .8 / steps
        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
        for i in range(0, steps):
            input_tensor[:, -1, :, :] += interval

            # performing actions on image
            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_images)

        # Add the aged residuals back onto copies of the original image
        image = image.repeat(steps, 1, 1, 1)

        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
        image = torch.clamp(image, 0, 1)
        image = (image * 255).to(torch.uint8)

        # delete=False so the file persists after this function returns
        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)

        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)  # (T, H, W, C) frames at 2 fps

        return output_file.name

    else:
        # performing actions on image
        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_image)

        # Add the aged residual back onto the original image
        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
        image = torch.clamp(image, 0, 1)

        return transforms.functional.to_pil_image(image)


def process_video(your_model, video_path, source_age, target_age, window_size=512, stride=256, frame_count=0):
    """
    Applying the aging to a video.
    We age as from source_age to target_age, and return an image.
    To limit the number of frames in a video, we can set frame_count.
    """

    # Extracting frames and placing them in a temporary directory
    frames_dir = tempfile.TemporaryDirectory()
    output_template = os.path.join(frames_dir.name, '%04d.jpg')

    if frame_count:
        # Keep only the first frame_count frames; the comma inside the filter
        # expression must be escaped for ffmpeg
        ff = FFmpeg(
            inputs={video_path: None},
            outputs={output_template: ['-vf', f'select=lt(n\\,{frame_count})', '-q:v', '1']}
        )
    else:
        ff = FFmpeg(
            inputs={video_path: None},
            outputs={output_template: ['-q:v', '1']}
        )

    ff.run()

    # Getting the framerate (for reconstruction later); r_frame_rate is a
    # rational string such as "30000/1001", parsed safely with Fraction
    ff = FFprobe(inputs={video_path: None},
                 global_options=['-v', 'error', '-select_streams', 'v', '-show_entries', 'stream=r_frame_rate', '-of',
                                 'default=noprint_wrappers=1:nokey=1'])
    stdout, _ = ff.run(stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    frame_rate = Fraction(stdout.decode('utf-8').strip())


    # Applying process_image to frames
    processed_dir = tempfile.TemporaryDirectory()

    for name in os.listdir(frames_dir.name):
        image_path = os.path.join(frames_dir.name, name)
        image = Image.open(image_path).convert('RGB')
        image_aged = process_image(your_model, image, False, source_age, target_age, window_size, stride)
        image_aged.save(os.path.join(processed_dir.name, name))

    # Generating a new video
    input_template = os.path.join(processed_dir.name, '%04d.jpg')
    output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    ff = FFmpeg(
        inputs={input_template: f'-framerate {frame_rate}'}, global_options=['-y'],
        outputs={output_file.name: ['-c:v', 'libx264', '-pix_fmt', 'yuv420p']}
    )

    ff.run()

    frames_dir.cleanup()
    processed_dir.cleanup()

    return output_file.name
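

# Example usage (a minimal sketch; the model class and checkpoint path are
# hypothetical placeholders, not defined in this module):
#
#     model = YourModel()                                      # hypothetical
#     model.load_state_dict(torch.load('path/to/model.pth'))   # hypothetical path
#     model.eval()
#
#     aged = process_image(model, Image.open('person.jpg'), video=False,
#                          source_age=25, target_age=60)
#     aged.save('person_aged.jpg')
#
#     video_out = process_video(model, 'clip.mp4', source_age=25, target_age=60,
#                               frame_count=100)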