File size: 2,729 Bytes
8096486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import sys
import torch
import torchaudio
import torchvision
import argparse

sys.path.insert(0, "../")

from lightning import ModelModule
from datamodule.transforms import VideoTransform


# Command-line interface. parse_known_args() tolerates extra flags that
# other components (e.g. Lightning) may add to sys.argv.
parser = argparse.ArgumentParser()
for flag, options in (
    ("--video",      dict(type=str, required=True,  help="Path to input video file")),
    ("--model_path", dict(type=str, required=True,  help="Path to VSR model checkpoint (.pth)")),
    ("--detector",   dict(type=str, default="retinaface", choices=["retinaface", "mediapipe"])),
):
    parser.add_argument(flag, **options)
args, _ = parser.parse_known_args()

class InferencePipeline(torch.nn.Module):
    """End-to-end visual speech recognition (VSR) inference pipeline.

    Reads a video file, detects face landmarks, crops/normalizes the mouth
    region, and decodes a transcript with a pretrained video-only model.
    """

    def __init__(self, ckpt_path, detector="retinaface"):
        """Build the pipeline.

        Args:
            ckpt_path: Path to the model checkpoint (a state_dict loadable
                into ``ModelModule.model``).
            detector: Face landmark backend, ``"retinaface"`` or
                ``"mediapipe"``.

        Raises:
            ValueError: If ``detector`` is not a recognized backend.
        """
        super().__init__()

        # Detector backends are imported lazily so that only the selected
        # backend's dependencies need to be installed.
        if detector == "mediapipe":
            from preparation.detectors.mediapipe.detector import LandmarksDetector
            from preparation.detectors.mediapipe.video_process import VideoProcess
            self.landmarks_detector = LandmarksDetector()
        elif detector == "retinaface":
            from preparation.detectors.retinaface.detector import LandmarksDetector
            from preparation.detectors.retinaface.video_process import VideoProcess
            # NOTE(review): device is hard-coded; assumes a CUDA GPU is
            # available — confirm before running on CPU-only hosts.
            self.landmarks_detector = LandmarksDetector(device="cuda:0")
        else:
            raise ValueError(
                f"Unknown detector {detector!r}; expected 'retinaface' or 'mediapipe'"
            )
        # Common to both backends: keep RGB frames (no grayscale conversion).
        self.video_process = VideoProcess(convert_gray=False)

        self.video_transform = VideoTransform(subset="test")

        # Load the checkpoint onto CPU regardless of where it was saved.
        # weights_only=False is required because the checkpoint may contain
        # non-tensor objects (PyTorch >= 2.6 defaults to weights_only=True).
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
        args_model = argparse.Namespace(modality="video")
        self.modelmodule = ModelModule(args_model)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Decode a video file; returns frames as a (T, H, W, C) numpy array."""
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def forward(self, data_filename):
        """Transcribe the speech visible in ``data_filename``.

        Args:
            data_filename: Path to a video file.

        Returns:
            The decoded transcript produced by the model.

        Raises:
            AssertionError: If ``data_filename`` does not exist.
        """
        data_filename = os.path.abspath(data_filename)
        assert os.path.isfile(data_filename), f"File not found: {data_filename}"

        video = self.load_video(data_filename)
        landmarks = self.landmarks_detector(video)
        video = self.video_process(video, landmarks)
        video = torch.tensor(video)
        # (T, H, W, C) -> (T, C, H, W) as expected by the transform/model.
        video = video.permute((0, 3, 1, 2))
        video = self.video_transform(video)

        with torch.no_grad():
            transcript = self.modelmodule(video)

        return transcript

if __name__ == "__main__":
    # Build the pipeline from the parsed CLI arguments and transcribe
    # the requested video in one shot.
    vsr = InferencePipeline(ckpt_path=args.model_path, detector=args.detector)
    print("Transcript:", vsr(args.video))