import os
import sys
import torch
import torchaudio
import torchvision
import argparse
sys.path.insert(0, "../")  # add the parent directory to the import path so the project modules resolve

from lightning import ModelModule
from datamodule.transforms import VideoTransform

parser = argparse.ArgumentParser()
parser.add_argument("--video", type=str, required=True, help="Path to input video file")
parser.add_argument("--model_path", type=str, required=True, help="Path to VSR model checkpoint (.pth)")
parser.add_argument("--detector", type=str, default="retinaface", choices=["retinaface", "mediapipe"])
args, _ = parser.parse_known_args()
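
# Example invocation (script and file names below are placeholders, adjust to your setup):
#   python infer.py --video path/to/clip.mp4 --model_path path/to/vsr_checkpoint.pth --detector retinaface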

class InferencePipeline(torch.nn.Module):
    def __init__(self, ckpt_path, detector="retinaface"):
        super().__init__()
        # Select the face/landmark detection backend requested via --detector.
        if detector == "mediapipe":
            from preparation.detectors.mediapipe.detector import LandmarksDetector
            from preparation.detectors.mediapipe.video_process import VideoProcess
            self.landmarks_detector = LandmarksDetector()
            self.video_process = VideoProcess(convert_gray=False)
        elif detector == "retinaface":
            from preparation.detectors.retinaface.detector import LandmarksDetector
            from preparation.detectors.retinaface.video_process import VideoProcess
            self.landmarks_detector = LandmarksDetector(device="cuda:0")
            self.video_process = VideoProcess(convert_gray=False)
        self.video_transform = VideoTransform(subset="test")

        # Load the VSR model checkpoint. weights_only=False is needed because newer
        # PyTorch releases default torch.load to weights_only=True, which rejects
        # fully pickled checkpoints.
        args_model = argparse.Namespace(modality="video")
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
        self.modelmodule = ModelModule(args_model)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()
    def load_video(self, data_filename):
        # torchvision.io.read_video returns (video, audio, info); keep the video
        # frames as a (T, H, W, C) uint8 numpy array.
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()

    def forward(self, data_filename):
        data_filename = os.path.abspath(data_filename)
        assert os.path.isfile(data_filename), f"File not found: {data_filename}"
        video = self.load_video(data_filename)
        # Detect landmarks, crop/align the region of interest, and apply the
        # test-time transform before running the model.
        landmarks = self.landmarks_detector(video)
        video = self.video_process(video, landmarks)
        video = torch.tensor(video)
        video = video.permute((0, 3, 1, 2))  # (T, H, W, C) -> (T, C, H, W)
        video = self.video_transform(video)
        with torch.no_grad():
            transcript = self.modelmodule(video)
        return transcript

if __name__ == "__main__":
    pipeline = InferencePipeline(ckpt_path=args.model_path, detector=args.detector)
    transcript = pipeline(args.video)
    print("Transcript:", transcript)