import argparse
import os
import sys

import torch
import torchaudio  # noqa: F401 -- NOTE(review): unused here; presumably registers a backend -- confirm before removing
import torchvision

sys.path.insert(0, "../")
from lightning import ModelModule
from datamodule.transforms import VideoTransform


class InferencePipeline(torch.nn.Module):
    """End-to-end visual speech recognition (VSR) inference pipeline.

    Given a path to a video file, detects facial landmarks, crops and
    normalizes the frames, and decodes a transcript with a pretrained
    VSR model.
    """

    def __init__(self, ckpt_path, detector="retinaface"):
        """Build the pipeline.

        Args:
            ckpt_path: Path to the VSR model checkpoint (.pth) whose
                state dict is loaded into ``ModelModule.model``.
            detector: Landmark detector backend, ``"retinaface"`` or
                ``"mediapipe"``.

        Raises:
            ValueError: If ``detector`` is not a known backend.
        """
        super().__init__()
        # Both backends expose the same callable interface; imports are kept
        # local so only the selected backend's dependencies are required.
        if detector == "mediapipe":
            from preparation.detectors.mediapipe.detector import LandmarksDetector
            from preparation.detectors.mediapipe.video_process import VideoProcess

            self.landmarks_detector = LandmarksDetector()
            self.video_process = VideoProcess(convert_gray=False)
        elif detector == "retinaface":
            from preparation.detectors.retinaface.detector import LandmarksDetector
            from preparation.detectors.retinaface.video_process import VideoProcess

            self.landmarks_detector = LandmarksDetector(device="cuda:0")
            self.video_process = VideoProcess(convert_gray=False)
        else:
            # argparse restricts CLI values, but guard direct construction too;
            # otherwise self.landmarks_detector would be silently undefined.
            raise ValueError(f"Unknown detector: {detector!r}")

        self.video_transform = VideoTransform(subset="test")

        # Load model weights. weights_only=False is needed because the
        # checkpoint is not a plain tensor-only state dict; pickle can
        # execute arbitrary code, so only load trusted checkpoints.
        args_model = argparse.Namespace(modality="video")
        ckpt = torch.load(
            ckpt_path,
            map_location=lambda storage, loc: storage,
            weights_only=False,
        )
        self.modelmodule = ModelModule(args_model)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Read all video frames as a numpy array (frames first, channels last)."""
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def forward(self, data_filename):
        """Run the full pipeline on one video file and return the transcript.

        Args:
            data_filename: Path to the input video file.

        Raises:
            FileNotFoundError: If ``data_filename`` does not exist.
        """
        data_filename = os.path.abspath(data_filename)
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"File not found: {data_filename}")
        video = self.load_video(data_filename)
        landmarks = self.landmarks_detector(video)
        video = self.video_process(video, landmarks)
        video = torch.tensor(video)
        video = video.permute((0, 3, 1, 2))  # (T, H, W, C) -> (T, C, H, W)
        video = self.video_transform(video)
        with torch.no_grad():
            transcript = self.modelmodule(video)
        return transcript


def _parse_args():
    """Parse command-line arguments for the inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", type=str, required=True, help="Path to input video file")
    parser.add_argument("--model_path", type=str, required=True, help="Path to VSR model checkpoint (.pth)")
    parser.add_argument("--detector", type=str, default="retinaface", choices=["retinaface", "mediapipe"])
    args, _ = parser.parse_known_args()
    return args


if __name__ == "__main__":
    # Parsing happens only when run as a script, so importing this module
    # no longer requires --video/--model_path on sys.argv (previously the
    # required arguments were parsed at import time and could abort imports).
    args = _parse_args()
    pipeline = InferencePipeline(ckpt_path=args.model_path, detector=args.detector)
    transcript = pipeline(args.video)
    print("Transcript:", transcript)