| import os |
| import sys |
| import torch |
| import torchaudio |
| import torchvision |
| import argparse |
|
|
| sys.path.insert(0, "../") |
|
|
| from lightning import ModelModule |
| from datamodule.transforms import VideoTransform |
|
|
|
|
# Command-line interface, parsed once at import time. parse_known_args is
# used (rather than parse_args) so unrelated extra flags are tolerated.
parser = argparse.ArgumentParser()
_ARG_SPECS = (
    ("--video", dict(type=str, required=True, help="Path to input video file")),
    ("--model_path", dict(type=str, required=True, help="Path to VSR model checkpoint (.pth)")),
    ("--detector", dict(type=str, default="retinaface", choices=["retinaface", "mediapipe"])),
)
for _flag, _options in _ARG_SPECS:
    parser.add_argument(_flag, **_options)
args, _ = parser.parse_known_args()
|
|
class InferencePipeline(torch.nn.Module):
    """End-to-end visual speech recognition (VSR) pipeline.

    Reads a video file, detects face landmarks, crops/normalizes the video
    with the selected detector backend, and decodes a transcript with a
    pretrained video-only model.
    """

    def __init__(self, ckpt_path, detector="retinaface", device="cuda:0"):
        """Build the pipeline.

        Args:
            ckpt_path: Path to the model checkpoint (a state_dict ``.pth``).
            detector: Landmark detector backend, ``"retinaface"`` or
                ``"mediapipe"``.
            device: Device string for the RetinaFace detector (unused by
                mediapipe). Defaults to ``"cuda:0"``, matching the previous
                hard-coded value.

        Raises:
            ValueError: If ``detector`` is not a supported backend.
        """
        super().__init__()

        # Detector backends are imported lazily so that only the selected
        # backend's dependencies need to be installed.
        if detector == "mediapipe":
            from preparation.detectors.mediapipe.detector import LandmarksDetector
            from preparation.detectors.mediapipe.video_process import VideoProcess
            self.landmarks_detector = LandmarksDetector()
            self.video_process = VideoProcess(convert_gray=False)
        elif detector == "retinaface":
            from preparation.detectors.retinaface.detector import LandmarksDetector
            from preparation.detectors.retinaface.video_process import VideoProcess
            self.landmarks_detector = LandmarksDetector(device=device)
            self.video_process = VideoProcess(convert_gray=False)
        else:
            # Fail fast: without this, an unknown detector left the object
            # half-constructed and forward() later died with AttributeError.
            raise ValueError(f"Unsupported detector: {detector!r}")

        self.video_transform = VideoTransform(subset="test")

        # This pipeline is video-only.
        args_model = argparse.Namespace(modality="video")

        # NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code — only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
        self.modelmodule = ModelModule(args_model)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Decode the video at ``data_filename`` into a (T, H, W, C) numpy array."""
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def forward(self, data_filename):
        """Run VSR on one video file and return the decoded transcript.

        Args:
            data_filename: Path to the input video.

        Raises:
            FileNotFoundError: If ``data_filename`` does not exist.
        """
        data_filename = os.path.abspath(data_filename)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"File not found: {data_filename}")

        video = self.load_video(data_filename)
        landmarks = self.landmarks_detector(video)
        video = self.video_process(video, landmarks)
        video = torch.tensor(video)
        video = video.permute((0, 3, 1, 2))  # (T, H, W, C) -> (T, C, H, W)
        video = self.video_transform(video)

        with torch.no_grad():
            transcript = self.modelmodule(video)

        return transcript
|
|
| if __name__ == "__main__": |
| pipeline = InferencePipeline(ckpt_path=args.model_path, detector=args.detector) |
| transcript = pipeline(args.video) |
| print("Transcript:", transcript) |