import argparse
import os
import sys

import torch
import gradio as gr
from huggingface_hub import hf_hub_download

# Make the vendored auto_avsr package importable before its modules are used.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "auto_avsr"))

from lightning import ModelModule
from datamodule.transforms import VideoTransform
from preparation.detectors.retinaface.detector import LandmarksDetector
from preparation.detectors.retinaface.video_process import VideoProcess

# Download the VSR checkpoint at startup. hf_hub_download caches locally,
# so repeated launches do not re-download.
print("Downloading VSR model from HuggingFace Hub...")
model_path = hf_hub_download(
    repo_id="okregent/visnet-model",
    filename="vsr_trlrs2lrs3vox2avsp_base.pth",
)
print(f"Model ready at: {model_path}")

# The auto_avsr ModelModule expects an argparse.Namespace; build a minimal one.
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args(args=[])
setattr(args, "modality", "video")


class InferencePipeline(torch.nn.Module):
    """End-to-end visual speech recognition pipeline.

    Splits an input video into fixed-length segments, detects face
    landmarks per segment, crops the mouth region, and decodes each
    segment with the auto_avsr model. Segment transcripts are joined
    with spaces into one output string.
    """

    SEGMENT_DURATION = 5  # seconds — matches the LRS3 training clip length
    MIN_FRAMES = 10  # skip segments shorter than this

    def __init__(self, args, ckpt_path, detector="retinaface"):
        """Load the checkpoint and build the preprocessing components.

        Args:
            args: argparse.Namespace with at least a ``modality`` attribute.
            ckpt_path: path to the model ``state_dict`` checkpoint.
            detector: kept for interface compatibility; only the RetinaFace
                landmark detector is instantiated here.
        """
        super().__init__()
        self.modality = args.modality
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.landmarks_detector = LandmarksDetector(device=device)
        self.video_process = VideoProcess(convert_gray=False)
        self.video_transform = VideoTransform(subset="test")
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects. Acceptable while the checkpoint comes from our
        # own HF repo, but tighten this if the model source ever changes.
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.modelmodule = ModelModule(args)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Read a video file.

        Returns:
            (frames, fps): frames as a numpy array in THWC order, and the
            container-reported frame rate (defaults to 25.0 when the
            metadata lacks ``video_fps``).
        """
        import torchvision  # local import: defers heavy import until first use

        frames, _, info = torchvision.io.read_video(data_filename, pts_unit="sec")
        fps = info.get("video_fps", 25.0)
        return frames.numpy(), fps

    def _process_segment(self, segment_frames):
        """Transcribe one segment; return "" when no usable face is found."""
        landmarks = self.landmarks_detector(segment_frames)
        processed = self.video_process(segment_frames, landmarks)
        if processed is None:
            # video_process signals an unusable segment (e.g. no face) with None.
            return ""
        video_tensor = torch.tensor(processed)
        video_tensor = video_tensor.permute((0, 3, 1, 2))  # THWC -> TCHW
        video_tensor = self.video_transform(video_tensor)
        with torch.no_grad():
            transcript = self.modelmodule(video_tensor)
        return transcript.strip()

    def forward(self, data_filename):
        """Transcribe a whole video file, segment by segment.

        Raises:
            FileNotFoundError: if ``data_filename`` does not exist.
        """
        data_filename = os.path.abspath(data_filename)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"File not found: {data_filename}")

        video, fps = self.load_video(data_filename)
        # Guard against zero/invalid fps in container metadata, which would
        # otherwise make range() step by 0 below and raise ValueError.
        segment_size = max(int(fps * self.SEGMENT_DURATION), 1)
        total_frames = len(video)

        transcripts = []
        for start in range(0, total_frames, segment_size):
            # Slicing clamps at the end of the array, so no explicit min() needed.
            segment = video[start:start + segment_size]
            if len(segment) < self.MIN_FRAMES:
                continue
            result = self._process_segment(segment)
            if result:
                transcripts.append(result)
        return " ".join(transcripts)


# Load model once at startup so every request shares one instance.
pipeline = InferencePipeline(args, model_path)


def transcribe(video_path):
    """Gradio handler: run the pipeline and map failures to friendly text."""
    if video_path is None:
        return "Please upload a video file."
    try:
        result = pipeline(video_path)
        if not result:
            return (
                "No speech detected. Make sure the video clearly shows "
                "a speaker's face (front-facing, good lighting)."
            )
        return result
    except Exception as e:
        # Top-level UI boundary: surface the error message instead of a 500.
        return f"Error: {str(e)}"


demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Video(label="Upload Video (mp4 / avi / mov, max 100 MB)"),
    outputs=gr.Textbox(
        label="Transcription",
        lines=6,
        show_copy_button=True,
    ),
    title="VisNet — Visual Speech Recognition",
    description=(
        "Upload a video to transcribe speech from lip movements — **no audio required**.\n\n"
        "**Tips for best results:** front-facing camera, clear face visibility, good lighting.\n\n"
        "⚠️ Running on CPU — inference may take several minutes for longer videos."
    ),
    allow_flagging="never",
)

demo.queue()
demo.launch()