File size: 19,452 Bytes
54edee1
 
dd31ccf
 
 
 
 
 
68fad5b
dd31ccf
 
68fad5b
dd31ccf
 
2797e34
 
68fad5b
 
 
 
 
c5ac071
0be44e5
2797e34
 
 
0be44e5
 
 
68fad5b
 
136be26
0be44e5
b040c7e
 
 
 
c5ac071
dd31ccf
 
 
 
 
 
43d63ec
 
 
 
 
 
b040c7e
 
 
 
 
 
 
 
 
 
 
 
 
43d63ec
dd31ccf
2797e34
dd31ccf
68fad5b
dd31ccf
 
256a090
 
 
 
 
 
2797e34
dd31ccf
50f301d
ef07006
63624e2
b5919b4
2797e34
dd31ccf
7c65853
 
 
 
 
 
 
dd31ccf
 
50f301d
7c65853
dd31ccf
2797e34
dd31ccf
2797e34
 
dd31ccf
 
 
7c65853
 
 
dd31ccf
 
 
 
 
 
 
 
2797e34
dd31ccf
a7f2e6c
dd31ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe51c8d
dd31ccf
fe51c8d
dd31ccf
fe51c8d
dd31ccf
 
 
 
 
 
 
 
 
 
 
68fad5b
0be44e5
 
dd31ccf
 
 
68fad5b
dd31ccf
 
 
 
 
 
0be44e5
 
dd31ccf
2797e34
 
 
dd31ccf
 
2797e34
 
 
 
 
 
dd31ccf
2797e34
 
 
 
 
dd31ccf
 
 
2797e34
 
 
 
 
dd31ccf
 
2797e34
 
 
 
 
 
 
 
 
 
136be26
 
 
 
 
 
 
2797e34
 
 
 
 
 
136be26
 
2797e34
 
 
 
5066b17
136be26
 
 
2797e34
 
 
 
 
 
136be26
 
 
2797e34
 
136be26
 
 
 
2797e34
 
 
 
 
 
 
136be26
 
 
 
2797e34
 
 
 
 
 
0be44e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fa65
 
 
 
 
 
 
 
 
fee163c
 
0be44e5
 
68fad5b
 
ca82116
 
68fad5b
ca82116
b040c7e
0be44e5
329d53d
 
0be44e5
 
 
 
 
 
 
 
 
 
 
68fad5b
 
 
 
 
 
 
 
 
 
 
 
 
 
ca82116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68fad5b
0437ce1
3a724c5
 
68fad5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0be44e5
68fad5b
 
0be44e5
 
 
c5ac071
 
 
 
 
 
0be44e5
ca82116
 
 
 
 
 
0be44e5
 
ca82116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a724c5
0be44e5
 
b040c7e
 
 
bf04385
b040c7e
 
 
 
 
 
 
 
 
ca82116
0be44e5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
#!/usr/bin/env python3

from typing import Dict, Any
import torch
from PIL import Image
import base64
from io import BytesIO
import numpy as np
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from transformers import CLIPVisionModelWithProjection
import cv2
import os
import sys
import skvideo.io
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import read_frames, get_fps, save_videos_grid

# import onnxruntime as ort
import gc
import subprocess

import requests
import tempfile

from rembg import remove
import onnxruntime as ort
import shutil

import firebase_admin
from firebase_admin import credentials, storage, firestore
import json

# Resolve the compute device once at import time. Importing this module on a
# machine without CUDA fails immediately with the ValueError below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if not device.type == 'cuda':
    raise ValueError("The model requires a GPU for inference.")

class EndpointHandler():
    def __init__(self, path=""):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'configs', 'prompts', 'animation.yaml')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The configuration file was not found at: {config_path}")

        service_account_info = os.getenv("FIREBASE_ACCOUNT_INFO")

        if not service_account_info:
            raise ValueError("The FIREBASE_SERVICE_ACCOUNT environment variable is not set.")
        service_account_info = service_account_info.replace('/\\n/g', '\n')

        service_account_info_dict = json.loads(service_account_info)

        cred = credentials.Certificate(service_account_info_dict)
        firebase_admin.initialize_app(cred, {
            'storageBucket': 'quiz-app-edffe.appspot.com'
        })

        self.config = OmegaConf.load(config_path)
        self.weight_dtype = torch.float16
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipeline = None
        self._initialize_pipeline()

    def _initialize_pipeline(self):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(base_dir, 'pretrained_weights', 'sd-vae-ft-mse')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"The sd-vae-ft-mse folder was not found at: {config_path}")

        vae = AutoencoderKL.from_pretrained(config_path).to(self.device, dtype=self.weight_dtype)

        pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
        print("model path is " + pretrained_base_model_path_unet)
        reference_unet = UNet2DConditionModel.from_pretrained(
            pretrained_base_model_path_unet
        ).to(dtype=self.weight_dtype, device=self.device)

        inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
        motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
        denoising_unet_path = os.path.join(base_dir, 'pretrained_weights', 'denoising_unet.pth')
        reference_unet_path = os.path.join(base_dir, 'pretrained_weights', 'reference_unet.pth')
        pose_guider_path = os.path.join(base_dir, 'pretrained_weights', 'pose_guider.pth')
        image_encoder_path = os.path.join(base_dir, 'pretrained_weights', 'image_encoder')

        infer_config = OmegaConf.load(inference_config_path)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_base_model_path_unet,
            motion_module_path,
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(self.device, dtype=self.weight_dtype)

        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(self.device, dtype=self.weight_dtype)
        image_enc = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=self.weight_dtype)
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        denoising_unet.load_state_dict(torch.load(denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(pose_guider_path, map_location="cpu"))

        self.pipeline = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler
        ).to(self.device, dtype=self.weight_dtype)

    def _crop_face(self, image, save_path="cropped_face.jpg", margin=0.5):
        """Crop the first detected face (plus a margin) and save it as a JPEG.

        Args:
            image: A PIL image, or anything ``np.asarray`` can turn into image
                data that ``Image.fromarray`` accepts.
            save_path: Where the cropped face is written (JPEG, quality 95).
            margin: Fraction of the detected face width/height added around
                the box. The full margin is applied left, right and below;
                half of it is applied above.

        Returns:
            The cropped face as an RGB PIL image.

        Raises:
            ValueError: If the Haar cascade finds no face.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image))

        # Normalise to 3-channel RGB before handing pixels to OpenCV. The
        # caller passes the background-removed image, which may carry an alpha
        # channel; reversing a 4-channel array would yield ABGR and
        # COLOR_BGR2GRAY would then weight the wrong channels. Single-channel
        # input would fail outright on the 3-axis slice below.
        rgb = np.array(image.convert("RGB"))
        cv_image = rgb[:, :, ::-1].copy()  # RGB -> BGR, the layout OpenCV expects

        # Load OpenCV face detector
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        # Detect faces
        gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)

        if len(faces) == 0:
            raise ValueError("No faces detected in the reference image.")

        # Crop the first face found with a margin
        x, y, w, h = faces[0]
        x_margin = int(margin * w)
        y_margin = int(margin * h)

        img_height, img_width = rgb.shape[:2]
        x1 = max(0, x - x_margin)
        y1 = max(0, y - y_margin // 2)  # Less margin at the top
        x2 = min(img_width, x + w + x_margin)
        y2 = min(img_height, y + h + y_margin)  # More margin at the bottom

        # Crop straight from the RGB array so no second channel flip is needed.
        cropped_face = Image.fromarray(np.ascontiguousarray(rgb[y1:y2, x1:x2])).convert("RGB")

        # Save the cropped face
        cropped_face.save(save_path, format="JPEG", quality=95)

        return cropped_face

    def _swap_face(self, source_path, target_video_path, output_path):
        """Configure ``roop`` through its module-level globals and run a face swap.

        NOTE(review): ``roop``, ``suggest_max_memory``,
        ``decode_execution_providers``, ``suggest_execution_threads``,
        ``get_frame_processors_modules`` and ``start`` are not imported in the
        visible part of this file, so calling this method as-is would raise
        ``NameError``. ``__call__`` performs the face swap through a
        ``facefusion`` subprocess instead and does not call this method.

        Args:
            source_path: Path assigned to ``roop.globals.source_path``.
            target_video_path: Path assigned to ``roop.globals.target_path``.
            output_path: Path assigned to ``roop.globals.output_path``.

        Returns:
            ``output_path`` joined onto the current working directory.
        """
        # source_path = "input.jpg"
        # source_image.save(source_path, format="JPEG", quality=95)

        roop.globals.source_path = source_path
        roop.globals.target_path = target_video_path
        roop.globals.output_path = output_path
        roop.globals.frame_processors = ["face_swapper", "face_enhancer"]
        roop.globals.headless = True
        roop.globals.keep_fps = True
        roop.globals.keep_audio = True
        roop.globals.keep_frames = False
        roop.globals.many_faces = False
        # roop.globals.video_encoder = "libx264"
        roop.globals.video_quality = 50
        roop.globals.max_memory = suggest_max_memory()

        # Set GPU execution provider
        roop.globals.execution_providers = decode_execution_providers(["CUDAExecutionProvider"])
        roop.globals.execution_threads = suggest_execution_threads()

        # Ensure onnxruntime is using the GPU
        ort.set_default_logger_severity(3)  # Suppress verbose logging
        providers = ['CUDAExecutionProvider']
        options = ort.SessionOptions()
        options.intra_op_num_threads = 1

        # Only processors that expose an ``onnx_session`` attribute are touched.
        for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
            if hasattr(frame_processor, 'onnx_session'):
                frame_processor.onnx_session.set_providers(providers, options)

        # Clear CUDA cache before starting the face swapping process
        torch.cuda.empty_cache()

        start()

        # Clear CUDA cache after the face swapping process
        # NOTE(review): ``roop.globals.frame_processors`` is the list of names
        # assigned above, and ``del`` only unbinds the loop variable, so this
        # loop does not release anything by itself.
        for frame_processor in roop.globals.frame_processors:
            del frame_processor
        torch.cuda.empty_cache()

        return os.path.join(os.getcwd(), output_path)

    def print_memory_stat_for_stuff(self, phase, log_file="memory_stats.log"):
        with open(log_file, "a") as f:
            f.write(f"Memory Stats - {phase}:\n")
            f.write(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB\n")
            f.write(f"Reserved memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB\n")
            f.write(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB\n")
            f.write(f"Max reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB\n")
            f.write("="*30 + "\n")

    def convert_to_playable_format(self, input_path, output_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            temp_output_path = tmp_file.name

        command = f"ffmpeg -i {input_path} -c:v libx264 -preset fast -crf 18 -y {temp_output_path}"

        # Run the command with shell=True
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        print("Conversion STDOUT:", result.stdout)
        print("Conversion STDERR:", result.stderr)
        
        if result.returncode != 0:
            raise RuntimeError(f"FFmpeg conversion failed with exit code {result.returncode}")

        shutil.move(temp_output_path, output_path)

    def run_rife_interpolation(self, video_path, output_path, multi=2, scale=1.0):
        base_dir = os.path.dirname(os.path.abspath(__file__))
        directory = os.path.join(base_dir, "Practical-RIFE", "inference_video.py")
        model_directory = os.path.join(base_dir, "Practical-RIFE", "train_log")
        command = f"python3 {directory} --video={video_path} --output={output_path} --multi={multi} --scale={scale} --model={model_directory}"

        # Run the command with shell=True
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        print(result)
        print(result.stdout)
        print(result.stderr)
        
        if result.returncode != 0:
            raise RuntimeError(f"RIFE interpolation failed with exit code {result.returncode}")
        
        # Overwrite the RIFE output with the converted playable format
        self.convert_to_playable_format(output_path, output_path)

    def speed_up_video(self, input_path, output_path, factor=4):
        command = f"ffmpeg -i {input_path} -filter:v setpts=PTS/{factor} -an {output_path}"

        # Run the command with shell=True
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        print("Speed Up Video STDOUT:", result.stdout)
        print("Speed Up Video STDERR:", result.stderr)
        
        if result.returncode != 0:
            raise RuntimeError(f"FFmpeg speed up failed with exit code {result.returncode}")

    def slow_down_video(self, input_path, output_path, factor=4):
        command = f"ffmpeg -i {input_path} -filter:v setpts={factor}*PTS -an {output_path}"

        # Run the command with shell=True
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        print("Slow Down Video STDOUT:", result.stdout)
        print("Slow Down Video STDERR:", result.stderr)
        
        if result.returncode != 0:
            raise RuntimeError(f"FFmpeg slow down failed with exit code {result.returncode}")

    def download_file(self, url: str, save_path: str, timeout: float = 60):
        """Stream ``url`` to ``save_path`` in 8 KiB chunks.

        Args:
            url: Address to fetch with an HTTP GET.
            save_path: File the response body is written to.
            timeout: Seconds passed to ``requests.get`` as ``timeout``, so a
                stalled server cannot block the handler indefinitely.

        Raises:
            ValueError: If the server answers with a status other than 200.
        """
        # The context manager closes the streamed response on every path,
        # including the non-200 branch where the body is never consumed.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise ValueError(f"Failed to download file from {url}")
            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

    def print_directory_contents(self, directory):
        for root, dirs, files in os.walk(directory):
            level = root.replace(directory, '').count(os.sep)
            indent = ' ' * 4 * (level)
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 4 * (level + 1)
            for f in files:
                print(f"{subindent}{f}")

    def print_directory_contents(self, path='.'):
        for root, dirs, files in os.walk(path):
            level = root.replace(path, '').count(os.sep)
            indent = ' ' * 4 * level
            print(f'{indent}{os.path.basename(root)}/')
            sub_indent = ' ' * 4 * (level + 1)
            for f in files:
                print(f'{sub_indent}{f}')

    def __call__(self, data: Any) -> Dict[str, str]:
        inputs = data.get("inputs", {})
        ref_image_url = inputs.get("ref_image_url", "")
        video_url = inputs.get("video_url", "")
        width = inputs.get("width", 512)
        height = inputs.get("height", 768)
        length = inputs.get("length", 96)
        num_inference_steps = inputs.get("num_inference_steps", 15)
        cfg = inputs.get("cfg", 3.5)
        seed = inputs.get("seed", -1)
        firebase_doc_id = inputs.get("firebase_doc_id", "")

        base_dir = os.path.dirname(os.path.abspath(__file__))

        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Temporary directory created at {temp_dir}")  # Debug statement
            video_root = os.path.join(temp_dir, "dw_poses_videos")
            os.makedirs(video_root, exist_ok=True)
            downloaded_video_path = os.path.join(video_root, "downloaded_video.mp4")
            downloaded_image_path = os.path.join(video_root, "downloaded_image.jpg")

            self.download_file(video_url, downloaded_video_path)
            self.download_file(ref_image_url, downloaded_image_path)
            ref_image = Image.open(downloaded_image_path)

            original_width, original_height = ref_image.size
            max_dimension = max(original_width, original_height)
            if max_dimension > 600:
                ratio = max_dimension / 600
                width = int(original_width / ratio)
                height = int(original_height / ratio)
            else:
                width = original_width
                height = original_height

            ref_image_no_bg = remove(ref_image)
            ref_image_no_bg_path = os.path.join(video_root, "ref_image_no_bg.png")
            ref_image_no_bg.save(ref_image_no_bg_path)

                        # pose_output_path = os.path.join(temp_dir, "pose_videos")

            # print("we are number 1")
            # # Run the extract_dwpose_from_vid.py script
            # extract_pose_path = os.path.join(base_dir, 'extract_dwpose_from_vid.py')
            # command = f'python3 {extract_pose_path} --video_root {video_root}'

            # # Run the command with shell=True
            # result = subprocess.run(command, shell=True, capture_output=True, text=True)
            # if result.returncode != 0:
            #     raise RuntimeError(f"Error running extract_dwpose_from_vid.py: {result.stderr}")
            # print("we are number 2")

            # # Locate the extracted pose video
            # save_dir = video_root + "_dwpose"
            # print(f"Expected save directory: {save_dir}")  # Debug statement
            # pose_video_path = os.path.join(save_dir, "downloaded_video.mp4")

            # if not os.path.exists(pose_video_path):
            #     print("Contents of the temporary directory:")
            #     self.print_directory_contents(temp_dir)
            #     raise FileNotFoundError(f"The pose video was not found at: {pose_video_path}")

            # Speed up the pose video by 4x
            # sped_up_pose_video_path = os.path.join(temp_dir, "sped_up_pose_video.mp4")
            # self.speed_up_video(downloaded_video_path, sped_up_pose_video_path, factor=2)

            torch.manual_seed(seed)

            pose_images = read_frames(downloaded_video_path)
            src_fps = get_fps(downloaded_video_path)
            
            pose_list = []
            total_length = min(length, len(pose_images))
            for pose_image_pil in pose_images[:total_length]:
                pose_list.append(pose_image_pil)

            video = self.pipeline(
                ref_image_no_bg,
                pose_list,
                width=width,
                height=height,
                video_length=total_length,
                num_inference_steps=num_inference_steps,
                guidance_scale=cfg
            ).videos

            save_dir = os.path.join(temp_dir, "output")
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            animation_path = os.path.join(save_dir, "animation_output.mp4")
            save_videos_grid(video, animation_path, n_rows=1, fps=src_fps)

            cropped_face_path = os.path.join(save_dir, "cropped_face.jpg")
            cropped_face = self._crop_face(ref_image_no_bg, save_path=cropped_face_path)

            torch.cuda.empty_cache()

            swapped_face_video_path = os.path.join(save_dir, "swapped_face_output.mp4")
            facefusion_script_path = os.path.join(base_dir, 'facefusion', 'core.py')
            swap_command = f'python3 {facefusion_script_path} --source {cropped_face_path} --target {animation_path} --output {swapped_face_video_path}'
            swap_result = subprocess.run(swap_command, shell=True, capture_output=True, text=True)
            if swap_result.returncode != 0:
                raise RuntimeError(f"Error running face swap: {swap_result.stderr}")

            # Slow down the produced video by 4x
            # self.print_directory_contents(temp_dir)
            # slowed_down_animation_path = os.path.join(save_dir, "slowed_down_animation_output.mp4")
            # self.slow_down_video(swapped_face_video_path, slowed_down_animation_path, factor=2)


            torch.cuda.empty_cache()

            #remove background
            # self.print_directory_contents()
            # removed_background_output_path = os.path.join(save_dir, "removed_background_result.mp4")
            # remove_background_script_path = os.path.join(base_dir, "rembg_video.py")
            # remove_background_command = f'python3 {remove_background_script_path} {swapped_face_video_path} {removed_background_output_path}'
            # print("Command is " + remove_background_command)
            # remove_background_result = subprocess.run(remove_background_command, shell=True, capture_output=True, text=True)
            # if remove_background_result.returncode != 0:
            #     raise RuntimeError(f"Error running removing backgriund: {remove_background_result.stderr}")


            # Perform RIFE interpolation
            # self.print_directory_contents(temp_dir)
            # rife_output_path = os.path.join(save_dir, "completed_result.mp4")
            # self.run_rife_interpolation(swapped_face_video_path, rife_output_path, multi=2, scale=0.5)


            with open(swapped_face_video_path, "rb") as video_file:
                video_base64 = base64.b64encode(video_file.read()).decode("utf-8")

            # Upload video to Firebase Storage
            bucket = storage.bucket()
            blob = bucket.blob(f"videos/{firebase_doc_id}/swapped_face_output.mp4")
            blob.upload_from_filename(swapped_face_video_path)
            
            # Make the file publicly accessible
            blob.make_public()

            video_url = blob.public_url

            # Update Firestore document
            db = firestore.client()
            doc_ref = db.collection('danceResults').document(firebase_doc_id)
            doc_ref.update({"videoResultUrl": video_url})

            return {"video": video_base64}