# head-animation / inference.py
# Author: Daryl Fung — initial commit (c79401e)
import os
import sys
import torch
import yaml
import logging
import json
import base64
import io
import string
import random
import threading
from skimage.transform import resize
from skimage import img_as_ubyte
import os  # NOTE(review): duplicate of the `import os` above — harmless, candidate for removal.
# Directory this file lives in; sibling project modules (modules/, demo,
# autocrop, ...) are imported relative to it, so it is pushed onto sys.path.
CURRENT_DIR = os.path.dirname(__file__)
sys.path.append(CURRENT_DIR)
logger = logging.getLogger(__name__)
# S3 bucket / result prefix — presumably consumed by deployment-side scripts;
# neither constant is referenced anywhere in this file (TODO confirm needed).
BUCKET_NAME = 'sagemaker-us-east-2-611162955425'
SAVE_VID_DIR = os.path.join('animate', 'video_results')
# Rendered videos are written under /tmp/ — presumably because it is the
# writable path inside the inference container; confirm against deployment.
VIDEO_DIR = '/tmp/'
import imageio
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from demo import make_animation
from autocrop import Cropper
# Driving videos bundled with the model; one is picked at random per request.
VIDEO_EXAMPLE_DIR = os.path.join(CURRENT_DIR, 'video_examples')
# VIDEO_DIR = os.path.join(CURRENT_DIR, 'video_results')
CONFIG_FILE = os.path.join(CURRENT_DIR, 'config', 'vox-256.yaml')
os.makedirs(VIDEO_DIR, exist_ok=True)
# True when no CUDA device is visible; drives all device placement below.
is_cpu = not torch.cuda.is_available()
def stringToImage(datauri):
    """Decode a base64 data-URI (``data:<mime>;base64,<payload>``) into an image.

    The text before the first comma (the MIME header) is dropped; the payload
    is base64-decoded and handed to imageio to parse into an image array.
    """
    pieces = datauri.split(',')
    raw_bytes = base64.b64decode(pieces[1])
    return imageio.imread(io.BytesIO(raw_bytes))
def getRandomVideoPath():
    """Return the full path of one randomly-chosen file in VIDEO_EXAMPLE_DIR."""
    chosen = random.choice(os.listdir(VIDEO_EXAMPLE_DIR))
    return os.path.join(VIDEO_EXAMPLE_DIR, chosen)
def model_fn(model_dir):
    """SageMaker model-loading hook.

    Reads the vox-256 YAML config, builds the generator and keypoint
    detector, restores their weights from ``vox-cpk.pth.tar`` inside
    *model_dir*, and returns ``[generator, kp_detector]`` in eval mode.
    """
    global is_cpu
    global CONFIG_FILE
    checkpoint_path = os.path.join(model_dir, 'vox-cpk.pth.tar')
    with open(CONFIG_FILE) as f:
        config = yaml.safe_load(f)
    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not is_cpu:
        generator.cuda()
        kp_detector.cuda()
    # map_location keeps CPU-only hosts from trying to deserialize CUDA tensors.
    map_location = torch.device('cpu') if is_cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    if not is_cpu:
        # BUG FIX: DataParallelWithCallback was referenced without ever being
        # imported, so GPU hosts raised NameError here. It lives in the
        # first-order-model project's sync_batchnorm package, which is on
        # sys.path via CURRENT_DIR — TODO confirm the package ships with the
        # model code bundle.
        from sync_batchnorm import DataParallelWithCallback
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return [generator, kp_detector]
def input_fn(request_body, request_content_type):
    """SageMaker deserialization hook.

    Accepts ``application/json`` bodies of two shapes:
      * ``{"video_id": ...}`` — a status query; returns
        ``[None, None, None, video_id]`` without any image work.
      * ``{"image": <data-uri>, "seconds": n}`` — an animation request;
        returns ``[source_image, driving_video, time_info, None]``.

    Raises when neither key is present or the content type is unsupported.
    """
    logger = logging.getLogger(__name__)
    logger.info('Deserializing the input data.')
    if request_content_type == 'application/json':
        input_data = json.loads(request_body)
        datauri_image = input_data.get('image', None)
        video_id = input_data.get('video_id', None)
        # BUG FIX: int(input_data.get('seconds', None)) raised TypeError
        # whenever "seconds" was absent — which crashed every pure
        # video_id status query before the early return below could run.
        seconds = input_data.get('seconds', None)
        output_seconds = int(seconds) if seconds is not None else None
        if video_id is None and datauri_image is None:
            raise Exception('Request must contain either a "video_id" key or an "image" key')
        # Status query: nothing to compute, pass the id straight through.
        if datauri_image is None:
            return [None, None, None, input_data['video_id']]
        source_image = stringToImage(datauri_image)
        # Crop a square (shorter image side) around the detected face.
        image_size = min(source_image.shape[:2])
        source_image = Cropper(width=image_size, height=image_size).crop(source_image)
        if source_image is None:
            # NOTE(review): this string is not valid JSON and does not match
            # the 4-element list contract predict_fn expects — left unchanged
            # to preserve the wire format, but consider raising instead.
            return '{success: Fail, message: "No faces found"}'
        reader = imageio.get_reader(getRandomVideoPath())
        fps = reader.get_meta_data()['fps']
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # imageio can raise on a truncated final frame; keep what we have.
            pass
        reader.close()
        source_image = resize(source_image, (image_size, image_size))[..., :3]
        driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
        time_info = {'fps': fps, 'output_seconds': output_seconds}
        return [source_image, driving_video, time_info, None]
    raise Exception(f'Requested unsupported ContentType in content_type {request_content_type}')
def predict_fn(input_data, model):
    """SageMaker prediction hook.

    Query requests (``video_id`` set) are answered immediately. Animation
    requests spawn a background thread that renders the video to
    ``VIDEO_DIR/<video_id>.mp4`` and return an 'accepted' ticket right away,
    so the HTTP call does not time out while frames are being generated.
    """
    logger = logging.getLogger(__name__)
    logger.info('Generating prediction based on input parameters.')
    generator, kp_detector = model
    source_image, driving_video, time_info, video_id = input_data
    # Pure status query: skip all computation.
    if video_id:
        return {'result': 'accepted', 'video_id': video_id, 'success': True, 'type': 'query'}
    # Ticket the caller can later poll with (5 random lowercase letters).
    # NOTE(review): ids can collide across requests — consider uuid4.
    video_id = ''.join(random.choice(string.ascii_lowercase) for _ in range(5))

    def make_animation_long(source_image, driving_video, generator, kp_detector):
        # Renders all frames, optionally ping-pongs them to reach the
        # requested duration, and writes the mp4 to VIDEO_DIR.
        # BUG FIX: cpu was hard-coded to True, ignoring the module-level
        # is_cpu flag (the original, commented-out call passed cpu=is_cpu);
        # on a GPU host the hard-coded value mismatched the device the
        # models were moved to in model_fn.
        predictions = make_animation(source_image, driving_video, generator, kp_detector,
                                     relative=True, adapt_movement_scale=True, cpu=is_cpu)
        video_filename = os.path.join(VIDEO_DIR, f'{video_id}.mp4')
        all_frames = [img_as_ubyte(frame) for frame in predictions]
        fps = time_info['fps']
        output_seconds = time_info['output_seconds']
        # Output runs at half the driving clip's frame rate — presumably to
        # slow the head motion down; TODO confirm intent.
        generated_fps = fps / 2
        # Extend short clips by mirroring frames (play backward, then
        # forward, alternating) until the requested duration is covered.
        if output_seconds is not None:
            new_frames = []
            num_frames_required = int(output_seconds * generated_fps)
            current_index = 1
            direction = -1
            while len(all_frames) + len(new_frames) < num_frames_required:
                new_frames.append(all_frames[current_index * direction])
                current_index += 1
                if current_index >= len(all_frames):
                    current_index = 1
                    direction *= -1
            all_frames = all_frames + new_frames
        imageio.mimsave(video_filename, all_frames, 'mp4', fps=generated_fps,
                        output_params=['-f', 'mp4'])
        logger.info(f'video saved at : {video_filename}')
        logger.info('success')

    # Fire-and-forget: the thread outlives this request; clients poll with
    # the returned video_id until output_fn finds the file on disk.
    thread = threading.Thread(target=make_animation_long,
                              args=[source_image, driving_video, generator, kp_detector])
    thread.start()
    return {'result': 'accepted', 'video_id': video_id, 'success': True, 'type': 'generate'}
def is_vid_exist(video_filename):
    """Return True when the rendered video file already exists on disk."""
    return os.path.isfile(video_filename)
def output_fn(outputs, content_type):
    """SageMaker serialization hook.

    For 'query' responses, attaches the finished video (base64-encoded) under
    ``outputs['video']`` when the file exists, or flips ``success`` to False
    when it is not ready yet. The result dict is serialized as JSON.

    Raises for any Accept type other than ``application/json``.
    """
    logger = logging.getLogger(__name__)
    logger.info('Serializing the generated output.')
    if outputs['type'] == 'query':
        video_filename = os.path.join(VIDEO_DIR, f'{outputs["video_id"]}.mp4')
        if is_vid_exist(video_filename):
            # Rendering finished: inline the mp4 as base64 so the JSON
            # response is self-contained.
            with open(video_filename, 'rb') as f:
                blob_video = f.read()
            outputs['video'] = base64.b64encode(blob_video).decode('utf-8')
        else:
            # Still rendering (or unknown id): tell the client to retry.
            outputs['success'] = False
    if content_type == 'application/json':
        return json.dumps(outputs), content_type
    raise Exception(f'Requested unsupported ContentType in Accept:{content_type}')