# NOTE(review): the original lines here ("Spaces:" / "Runtime error" x2) were
# paste-site artifacts, not Python source; replaced with this comment so the
# module parses.
# SageMaker inference entry points for the First Order Motion Model
# (vox-256): model_fn / input_fn / predict_fn / output_fn.

# --- standard library ---
# Fix: `os` was imported twice (here and again further down); imports are now
# grouped stdlib / third-party / project-local per convention.
import base64
import io
import json
import logging
import os
import random
import string
import sys
import threading

# --- third party ---
import imageio
import torch
import yaml
from skimage import img_as_ubyte
from skimage.transform import resize

# Make the project modules importable when this file runs from its own dir.
CURRENT_DIR = os.path.dirname(__file__)
sys.path.append(CURRENT_DIR)

logger = logging.getLogger(__name__)

# S3 bucket / prefix used by the surrounding deployment (not read in this file).
BUCKET_NAME = 'sagemaker-us-east-2-611162955425'
SAVE_VID_DIR = os.path.join('animate', 'video_results')

# Generated videos are written to /tmp, so they live only as long as the
# serving container instance.
VIDEO_DIR = '/tmp/'

# --- project local (these depend on the sys.path tweak above) ---
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from demo import make_animation
from autocrop import Cropper

# Example driving videos shipped alongside this file.
VIDEO_EXAMPLE_DIR = os.path.join(CURRENT_DIR, 'video_examples')
CONFIG_FILE = os.path.join(CURRENT_DIR, 'config', 'vox-256.yaml')

os.makedirs(VIDEO_DIR, exist_ok=True)

# Fall back to CPU inference when no CUDA device is visible.
is_cpu = not torch.cuda.is_available()
def stringToImage(datauri):
    """Decode a base64 data URI (e.g. "data:image/png;base64,...") into an image array.

    Only the payload after the first comma-delimited field is decoded.
    """
    payload = datauri.split(',')[1]
    buffer = io.BytesIO(base64.b64decode(payload))
    return imageio.imread(buffer)
def getRandomVideoPath():
    """Return the full path of a randomly chosen bundled example driving video."""
    chosen = random.choice(os.listdir(VIDEO_EXAMPLE_DIR))
    return os.path.join(VIDEO_EXAMPLE_DIR, chosen)
def model_fn(model_dir):
    """SageMaker model-loading hook.

    Builds the generator / keypoint-detector pair described by CONFIG_FILE,
    loads weights from ``vox-cpk.pth.tar`` inside *model_dir*, and returns
    both networks in eval mode as ``[generator, kp_detector]``.
    """
    checkpoint_path = os.path.join(model_dir, 'vox-cpk.pth.tar')
    with open(CONFIG_FILE) as f:
        config = yaml.safe_load(f)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not is_cpu:
        generator.cuda()
        kp_detector.cuda()

    # NOTE: torch.load unpickles the checkpoint — only load trusted model
    # artifacts through this path.
    if is_cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not is_cpu:
        # Fix: DataParallelWithCallback was referenced without ever being
        # imported, so every GPU instance crashed with NameError here. It is
        # provided by the project's sync_batchnorm package (as in demo.py).
        from sync_batchnorm import DataParallelWithCallback
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    return [generator, kp_detector]
def input_fn(request_body, request_content_type):
    """SageMaker deserialization hook.

    Accepts an ``application/json`` body containing either:
      * ``video_id`` — poll for a previously requested video, or
      * ``image``    — a base64 data URI to animate, optionally with
                       ``seconds`` (requested output duration).

    Returns ``[source_image, driving_video, time_info, video_id]``: the first
    three are None for pure queries, and video_id is None for generation
    requests. Raises for any other content type.
    """
    logger.info('Deserializing the input data.')
    if request_content_type == 'application/json':
        input_data = json.loads(request_body)
        datauri_image = input_data.get('image', None)
        video_id = input_data.get('video_id', None)
        # Fix: 'seconds' is optional, and the original int(input_data.get(...))
        # raised TypeError (int(None)) on every request that omitted it —
        # including all video_id-only queries.
        seconds = input_data.get('seconds', None)
        output_seconds = int(seconds) if seconds is not None else None
        if video_id is None and datauri_image is None:
            raise Exception('Request must contain either a "video_id" key or an "image" key')
        # Query path: nothing to compute, just pass the id through.
        if datauri_image is None:
            return [None, None, None, input_data['video_id']]

        source_image = stringToImage(datauri_image)
        # Crop a square around the detected face, sized to the image's
        # shorter edge.
        image_size = min(source_image.shape[:2])
        source_image = Cropper(width=image_size, height=image_size).crop(source_image)
        if source_image is None:
            # Fix: the previous literal was not valid JSON.
            # NOTE(review): callers still receive a string here where a
            # 4-element list is otherwise expected — predict_fn cannot unpack
            # it; confirm how the serving stack surfaces this case.
            return json.dumps({'success': False, 'message': 'No faces found'})

        reader = imageio.get_reader(getRandomVideoPath())
        fps = reader.get_meta_data()['fps']
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Some containers raise RuntimeError at end-of-stream; keep
            # whatever frames were read so far (best effort, deliberate).
            pass
        reader.close()

        source_image = resize(source_image, (image_size, image_size))[..., :3]
        driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
        time_info = {'fps': fps, 'output_seconds': output_seconds}
        return [source_image, driving_video, time_info, None]
    raise Exception(f'Requested unsupported ContentType in content_type {request_content_type}')
def predict_fn(input_data, model):
    """Kick off (or query the status of) a video generation.

    For query requests (video_id present in input_data) an acceptance dict is
    returned immediately. For generation requests, a background thread is
    started that renders the animation and writes it to
    VIDEO_DIR/<video_id>.mp4; the freshly generated video_id is returned so
    the client can poll for the result later via the query path.
    """
    global is_cpu
    global logger
    logger.info('Generating prediction based on input parameters.')
    generator, kp_detector = model
    source_image, driving_video, time_info, video_id = input_data
    # if this is querying then we return without computing the image animation
    if video_id:
        return {'result': 'accepted', 'video_id': video_id, 'success': True, 'type': 'query'}
    # generate an id to send back to the user, so user can query and check if the video has finished processing
    # NOTE(review): 5 random lowercase letters (~11.8M combinations); collisions
    # are unlikely but possible, and this id is not a security token.
    letters = string.ascii_lowercase
    video_id = ''.join(random.choice(letters) for i in range(5))
    def make_animation_long(source_image, driving_video, generator, kp_detector):
        # Runs on a worker thread; the only output channel is the file written
        # to video_filename below.
        # NOTE(review): cpu=True is hard-coded even when a GPU is available —
        # confirm whether this should be cpu=is_cpu.
        predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=True,
                                     adapt_movement_scale=True, cpu=True)
        video_filename = os.path.join(VIDEO_DIR, f'{video_id}.mp4')
        all_frames = [img_as_ubyte(frame) for frame in predictions]
        fps = time_info['fps']
        output_seconds = time_info['output_seconds']
        # Output plays at half the driving video's frame rate.
        generated_fps = fps / 2
        # generate longer video by mirroring the frames
        if output_seconds is not None:
            new_frames = []
            num_frames_required = int(output_seconds * generated_fps)
            current_index = 1
            direction = -1
            # Ping-pong over the rendered frames (backwards, then forwards,
            # then backwards again ...) until enough frames exist.
            # NOTE(review): the first mirrored frame is all_frames[-1], i.e.
            # the final frame appears twice in a row — confirm whether the
            # mirror should start one frame earlier.
            while len(all_frames) + len(new_frames) < num_frames_required:
                new_frames.append(all_frames[current_index * direction])
                current_index += 1
                if current_index >= len(all_frames):
                    current_index = 1
                    direction *= -1
            all_frames = all_frames + new_frames
        # save video head movement
        imageio.mimsave(video_filename, all_frames, 'mp4', fps=generated_fps,
                        output_params=['-f', 'mp4'])
        logger.info(f'video saved at : {video_filename}')
        logger.info('success')
    # Fire-and-forget: the HTTP response returns before rendering finishes.
    thread = threading.Thread(target=make_animation_long, args=[source_image, driving_video, generator, kp_detector])
    thread.start()
    result = {'result': 'accepted', 'video_id': video_id, 'success': True, 'type': 'generate'}
    return result
def is_vid_exist(video_filename):
    """Return True if the rendered video file exists on disk, else False.

    Used by output_fn to decide whether a background render has finished.
    """
    # The original if/return True/return False collapses to the predicate.
    return os.path.isfile(video_filename)
def output_fn(outputs, content_type):
    """Serialize the handler result as JSON.

    Query responses additionally carry the rendered video, base64-encoded
    under the 'video' key, once the background render has finished; until
    then 'success' is flipped to False. Any Accept type other than
    application/json raises.
    """
    logger.info('Serializing the generated output.')
    if outputs['type'] == 'query':
        target_path = os.path.join(VIDEO_DIR, f'{outputs["video_id"]}.mp4')
        if is_vid_exist(target_path):
            with open(target_path, 'rb') as fh:
                outputs['video'] = base64.b64encode(fh.read()).decode('utf-8')
        else:
            outputs['success'] = False
    if content_type == 'application/json':
        return json.dumps(outputs), content_type
    raise Exception(f'Requested unsupported ContentType in Accept:{content_type}')