| import os |
| import sys |
| import argparse |
| import torch |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Make the vendored "auto_avsr" checkout importable; this MUST run before the
# project imports below, which resolve against that directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "auto_avsr"))


# Project-local imports from the auto_avsr tree (path inserted above).
from lightning import ModelModule
from datamodule.transforms import VideoTransform
from preparation.detectors.retinaface.detector import LandmarksDetector
from preparation.detectors.retinaface.video_process import VideoProcess
|
|
| |
# Fetch the pretrained VSR checkpoint from the Hub (cached after first run).
VSR_REPO_ID = "okregent/visnet-model"
VSR_CHECKPOINT = "vsr_trlrs2lrs3vox2avsp_base.pth"

print("Downloading VSR model from HuggingFace Hub...")
model_path = hf_hub_download(repo_id=VSR_REPO_ID, filename=VSR_CHECKPOINT)
print(f"Model ready at: {model_path}")
|
|
| |
# Build the minimal argument namespace the model stack expects. This app is
# lipreading-only, so the modality is fixed to "video".
parser = argparse.ArgumentParser()
args, _ = parser.parse_known_args(args=[])
# Direct attribute assignment instead of setattr with a constant name (B010).
args.modality = "video"
|
|
|
|
class InferencePipeline(torch.nn.Module):
    """End-to-end visual speech recognition pipeline.

    Reads a video file, detects face landmarks, crops/normalizes the mouth
    region, and decodes a transcript with the pretrained VSR model. Long
    videos are processed in fixed-duration segments to bound memory use.
    """

    # Seconds of video per inference segment.
    SEGMENT_DURATION = 5
    # Segments with fewer frames than this are skipped (too little context).
    MIN_FRAMES = 10

    def __init__(self, args, ckpt_path, detector="retinaface"):
        """
        Args:
            args: namespace exposing at least a ``modality`` attribute.
            ckpt_path: path to the checkpoint file (raw model state dict).
            detector: face-detector backend name; only "retinaface" is wired
                up in this implementation (the argument is currently unused).
        """
        super().__init__()
        self.modality = args.modality

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.landmarks_detector = LandmarksDetector(device=device)
        self.video_process = VideoProcess(convert_gray=False)
        self.video_transform = VideoTransform(subset="test")

        # NOTE(review): torch.load uses pickle under the hood. The checkpoint
        # comes from a fixed HF repo here, but consider weights_only=True
        # (torch >= 1.13) to harden against malicious checkpoints.
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.modelmodule = ModelModule(args)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Decode a video file; return ``(frames_array, fps)``."""
        # Imported lazily so merely importing this module stays cheap.
        import torchvision

        frames, _, info = torchvision.io.read_video(data_filename, pts_unit="sec")
        # Some containers omit fps metadata; fall back to a common default.
        fps = info.get("video_fps", 25.0)
        return frames.numpy(), fps

    def _process_segment(self, segment_frames):
        """Transcribe one segment; return "" when no usable face was found."""
        landmarks = self.landmarks_detector(segment_frames)
        processed = self.video_process(segment_frames, landmarks)
        if processed is None:
            # No face / landmarks detected in this segment.
            return ""

        video_tensor = torch.tensor(processed)
        video_tensor = video_tensor.permute((0, 3, 1, 2))  # T,H,W,C -> T,C,H,W
        video_tensor = self.video_transform(video_tensor)

        with torch.no_grad():
            transcript = self.modelmodule(video_tensor)

        return transcript.strip()

    def forward(self, data_filename):
        """Transcribe the whole video at ``data_filename``.

        Returns:
            The concatenated per-segment transcripts, space-joined.

        Raises:
            FileNotFoundError: if ``data_filename`` does not exist.
        """
        data_filename = os.path.abspath(data_filename)
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"File not found: {data_filename}")

        video, fps = self.load_video(data_filename)
        # Guard against degenerate fps metadata; range() step must be >= 1.
        segment_size = max(int(fps * self.SEGMENT_DURATION), 1)
        total_frames = len(video)
        transcripts = []

        for start in range(0, total_frames, segment_size):
            end = min(start + segment_size, total_frames)
            segment = video[start:end]
            if len(segment) < self.MIN_FRAMES:
                continue
            result = self._process_segment(segment)
            if result:
                transcripts.append(result)

        return " ".join(transcripts)
|
|
|
|
| |
# Build the pipeline once at startup so every request reuses the loaded model.
pipeline = InferencePipeline(args, model_path)
|
|
|
|
def transcribe(video_path):
    """Run the VSR pipeline on an uploaded video and return transcript text.

    Returns a user-facing message (never raises) so errors surface in the
    Gradio textbox instead of crashing the request.
    """
    if video_path is None:
        return "Please upload a video file."

    try:
        transcript = pipeline(video_path)
    except Exception as exc:  # boundary handler: report, don't crash the UI
        return f"Error: {str(exc)}"

    if transcript:
        return transcript
    return (
        "No speech detected. Make sure the video clearly shows "
        "a speaker's face (front-facing, good lighting)."
    )
|
|
|
|
# Gradio UI definition. User-facing strings below had mojibake ("β" for an
# em dash, "β οΈ" for the warning emoji) — restored to the intended characters.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Video(label="Upload Video (mp4 / avi / mov, max 100 MB)"),
    outputs=gr.Textbox(
        label="Transcription",
        lines=6,
        show_copy_button=True,
    ),
    title="VisNet — Visual Speech Recognition",
    description=(
        "Upload a video to transcribe speech from lip movements — **no audio required**.\n\n"
        "**Tips for best results:** front-facing camera, clear face visibility, good lighting.\n\n"
        "⚠️ Running on CPU — inference may take several minutes for longer videos."
    ),
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
    # flagging_mode="never"; kept as-is for compatibility with the pinned
    # Gradio version — confirm before upgrading.
    allow_flagging="never",
)


demo.queue()
demo.launch()
|
|