import face_recognition
import numpy as np
import os
import torch
from fractions import Fraction
from torchvision import transforms
from torchvision.io import write_video
import tempfile
import subprocess
from ffmpy import FFmpeg, FFprobe
from PIL import Image
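# Blending masks, normalized to [0, 1]: mask1024 weights the full 1024x1024
# output and mask512 weights each 512x512 window, presumably to feather the
# edges of the aged region so it blends smoothly into the original image.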
mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
"""
Apply aging operation on input tensor using a sliding-window method. This operation is done on the GPU, if available.
"""
input_tensor = input_tensor.to(next(your_model.parameters()).device)
mask = mask.to(next(your_model.parameters()).device)
small_mask = small_mask.to(next(your_model.parameters()).device)
n, c, h, w = input_tensor.size()
output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
add = 2 if window_size % stride != 0 else 1
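    # Pad the loop bound so the last valid window position (h - window_size) is
    # still visited; `add` nudges the bound when stride does not evenly divide
    # window_size.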
for y in range(0, h - window_size + add, stride):
for x in range(0, w - window_size + add, stride):
window = input_tensor[:, :, y:y + window_size, x:x + window_size]
            # Forward pass on the current window (torch.autograd.Variable is
            # deprecated, so the window tensor is fed to the model directly)
            with torch.no_grad():
                output = your_model(window)
output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
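    # Pixels never covered by a window keep a count of zero; clamping avoids a
    # division by zero for them (such border pixels are presumably suppressed
    # by the mask anyway).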
count_tensor = torch.clamp(count_tensor, min=1.0)
# Average the overlapping regions
output_tensor /= count_tensor
# Apply mask
output_tensor *= mask
return output_tensor.cpu()
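

# Example usage (hypothetical names): given a trained aging model `model` and a
# (1, 5, 1024, 1024) input tensor `inp` (an RGB crop plus two age channels, as
# built in process_image below), the masked aging delta is obtained with:
#   delta = sliding_window_tensor(inp, window_size=512, stride=256, your_model=model)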
def process_image(your_model, image, video, source_age, target_age=0,
window_size=512, stride=256, steps=18):
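    """
    Age a single face image from source_age to target_age.
    Returns a PIL image, or (when video=True) the path to a temporary .mp4
    showing the aging progression in `steps` increments.
    """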
input_size = (1024, 1024)
    # Robustly coerce the input into a contiguous uint8 RGB array, as required
    # by face_recognition (PIL and numpy are already imported at module level)
    if isinstance(image, Image.Image):
image = image.convert('RGB')
image = np.array(image)
elif isinstance(image, np.ndarray):
if image.ndim == 2: # grayscale
image = np.stack([image]*3, axis=-1)
elif image.shape[2] == 4: # RGBA
image = image[..., :3]
if image.dtype == np.float32 or image.dtype == np.float64:
if image.max() <= 1.0:
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
elif image.dtype != np.uint8:
image = image.astype(np.uint8)
else:
        image = np.array(Image.fromarray(image).convert('RGB'))
# Ensure shape is (H, W, 3) and contiguous
if image.ndim != 3 or image.shape[2] != 3:
raise ValueError(f"Image must have shape (H, W, 3), got {image.shape}")
image = np.ascontiguousarray(image, dtype=np.uint8)
print(f"[DEBUG] image type: {type(image)}, shape: {image.shape}, dtype: {image.dtype}, contiguous: {image.flags['C_CONTIGUOUS']}")
    if video:  # The h264 codec requires both frame dimensions to be divisible by 2.
        height, width = image.shape[:2]
        new_height = height if height % 2 == 0 else height - 1
        new_width = width if width % 2 == 0 else width - 1
        # Crop to even dimensions; ndarray.resize would reshuffle the pixel
        # data in place rather than trim the frame.
        image = np.ascontiguousarray(image[:new_height, :new_width, :])
# Diagnostic: try face_recognition on this image, and if it fails, save and reload
try:
fl = face_recognition.face_locations(image)[0]
except Exception as e:
print(f"[DEBUG] face_locations failed: {e}. Saving image for test...")
        temp_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name
        Image.fromarray(image).save(temp_path)
print(f"[DEBUG] Saved image to {temp_path}. Trying face_recognition.load_image_file...")
loaded_img = face_recognition.load_image_file(temp_path)
print(f"[DEBUG] loaded_img type: {type(loaded_img)}, shape: {loaded_img.shape}, dtype: {loaded_img.dtype}")
fl = face_recognition.face_locations(loaded_img)[0]
# calculate margins
margin_y_t = int((fl[2] - fl[0]) * .63 * .85) # larger as the forehead is often cut off
margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
margin_x = int((fl[1] - fl[3]) // (2 / .85))
    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # total vertical margin equals the horizontal one, keeping the crop roughly square
l_y = max([fl[0] - margin_y_t, 0])
r_y = min([fl[2] + margin_y_b, image.shape[0]])
l_x = max([fl[3] - margin_x, 0])
r_x = min([fl[1] + margin_x, image.shape[1]])
# crop image
cropped_image = image[l_y:r_y, l_x:r_x, :]
# Resizing
orig_size = cropped_image.shape[:2]
cropped_image = transforms.ToTensor()(cropped_image)
    cropped_image_resized = transforms.Resize(input_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(cropped_image)
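    # Encode the source and target ages as two extra constant-valued channels
    # (age / 100, i.e. scaled to [0, 1]) concatenated onto the RGB crop,
    # forming the 5-channel model input.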
source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
image = transforms.ToTensor()(image)
if video:
        # Age in `steps` increments: each frame advances the target-age channel
        # by `interval`, sweeping 0.8 (i.e. 80 years) over the whole clip
        interval = .8 / steps
aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
for i in range(0, steps):
input_tensor[:, -1, :, :] += interval
# performing actions on image
aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)
# resize back to original size
        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_images)
        # Re-apply: the model output is a masked residual, added back onto the full-resolution frames
image = image.repeat(steps, 1, 1, 1)
image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
image = torch.clamp(image, 0, 1)
image = (image * 255).to(torch.uint8)
output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
write_video(output_file.name, image.permute(0, 2, 3, 1), 2)
return output_file.name
else:
# performing actions on image
aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
# resize back to original size
        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(
            aged_cropped_image)
        # Re-apply the masked residual onto the full-resolution image
image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
image = torch.clamp(image, 0, 1)
return transforms.functional.to_pil_image(image)
def process_video(your_model, video_path, source_age, target_age, window_size=512, stride=256, frame_count=0):
"""
    Apply the aging to a video, frame by frame.
    We age from source_age to target_age and return the path to the processed video file.
    To limit the number of frames processed, set frame_count (0 processes all frames).
"""
# Extracting frames and placing them in a temporary directory
frames_dir = tempfile.TemporaryDirectory()
output_template = os.path.join(frames_dir.name, '%04d.jpg')
if frame_count:
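        # The select filter keeps only frames with index n < frame_count.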
ff = FFmpeg(
inputs={video_path: None},
outputs={output_template: ['-vf', f'select=lt(n\,{frame_count})', '-q:v', '1']}
)
else:
ff = FFmpeg(
inputs={video_path: None},
outputs={output_template: ['-q:v', '1']}
)
ff.run()
# Getting framerate (for reconstruction later)
ff = FFprobe(inputs={video_path: None},
global_options=['-v', 'error', '-select_streams', 'v', '-show_entries', 'stream=r_frame_rate', '-of',
'default=noprint_wrappers=1:nokey=1'])
stdout, _ = ff.run(stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # r_frame_rate is reported as e.g. "30000/1001"; parse it with Fraction
    # instead of eval() (str(Fraction) formats back to a form ffmpeg accepts)
    frame_rate = Fraction(stdout.decode('utf-8').strip())
# Applying process_image to frames
processed_dir = tempfile.TemporaryDirectory()
for name in os.listdir(frames_dir.name):
image_path = os.path.join(frames_dir.name, name)
image = Image.open(image_path).convert('RGB')
image_aged = process_image(your_model, image, False, source_age, target_age, window_size, stride)
image_aged.save(os.path.join(processed_dir.name, name))
    # Generating a new video from the processed frames (yuv420p keeps the output playable in most players)
input_template = os.path.join(processed_dir.name, '%04d.jpg')
output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
ff = FFmpeg(
inputs={input_template: f'-framerate {frame_rate}'}, global_options=['-y'],
outputs={output_file.name: ['-c:v', 'libx264', '-pix_fmt', 'yuv420p']}
)
ff.run()
frames_dir.cleanup()
processed_dir.cleanup()
return output_file.name
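

# Example usage (hypothetical checkpoint loading; the actual model format may differ):
#   model = torch.load('path/to/aging_model.pt', map_location='cpu')
#   model.eval()
#   out_path = process_video(model, 'input.mp4', source_age=25, target_age=65, frame_count=50)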