# visnet / app.py
# Author: okregent
# Commit b5ece2c: "Add Gradio app with full VSR pipeline"
import os
import sys
import argparse
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
# Add auto_avsr to Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "auto_avsr"))
from lightning import ModelModule
from datamodule.transforms import VideoTransform
from preparation.detectors.retinaface.detector import LandmarksDetector
from preparation.detectors.retinaface.video_process import VideoProcess
# Fetch the pretrained VSR checkpoint from the Hugging Face Hub when the app
# starts; the returned local file path is later handed to the pipeline.
_VSR_REPO = "okregent/visnet-model"
_VSR_CHECKPOINT = "vsr_trlrs2lrs3vox2avsp_base.pth"

print("Downloading VSR model from HuggingFace Hub...")
model_path = hf_hub_download(filename=_VSR_CHECKPOINT, repo_id=_VSR_REPO)
print(f"Model ready at: {model_path}")
# ModelModule is configured from an argparse namespace. Parse an explicit
# empty argv so the host process's own command-line flags are never read,
# then pin the pipeline to the video-only modality.
parser = argparse.ArgumentParser()
args = parser.parse_known_args(args=[])[0]
args.modality = "video"
class InferencePipeline(torch.nn.Module):
    """Visual-only speech recognition over a video file.

    The video is read frame by frame, split into fixed-duration segments,
    and each segment is passed through the RetinaFace landmarks detector,
    ``VideoProcess``, ``VideoTransform`` and the VSR ``ModelModule``.
    Non-empty per-segment transcripts are joined with single spaces.
    """

    SEGMENT_DURATION = 5  # seconds — matches the LRS3 training clip length
    MIN_FRAMES = 10  # skip segments shorter than this
    DEFAULT_FPS = 25.0  # used when the container reports no usable frame rate

    def __init__(self, args, ckpt_path, detector="retinaface"):
        """Build the preprocessing front-end and load the VSR weights.

        Args:
            args: argparse-style namespace passed to ``ModelModule``; its
                ``modality`` attribute is also stored on the pipeline.
            ckpt_path: path to a checkpoint whose loaded contents are fed
                directly to ``ModelModule.model.load_state_dict``.
            detector: currently unused; only the RetinaFace landmarks
                detector is constructed. Kept for API compatibility.
        """
        super().__init__()
        self.modality = args.modality
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # NOTE(review): only the landmarks detector receives ``device``; the
        # VSR model below is never moved to it here. Confirm this is intended.
        self.landmarks_detector = LandmarksDetector(device=device)
        self.video_process = VideoProcess(convert_gray=False)
        self.video_transform = VideoTransform(subset="test")
        # Map every storage to CPU regardless of the device it was saved from.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True on recent torch).
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.modelmodule = ModelModule(args)
        self.modelmodule.model.load_state_dict(ckpt)
        self.modelmodule.eval()

    def load_video(self, data_filename):
        """Read every frame of *data_filename*.

        Returns:
            ``(frames, fps)``: the frames as a NumPy array (torchvision's
            default THWC layout) and the reported video frame rate, or
            ``DEFAULT_FPS`` when the metadata has no ``video_fps`` entry.
        """
        import torchvision

        frames, _, info = torchvision.io.read_video(data_filename, pts_unit="sec")
        fps = info.get("video_fps", self.DEFAULT_FPS)
        return frames.numpy(), fps

    def _segment_size(self, fps):
        """Return the segment length in frames, never less than one frame."""
        import math

        # A missing, zero, negative or non-finite rate would otherwise give a
        # zero/invalid range() step (ValueError) or a TypeError.
        if fps is None or not math.isfinite(fps) or fps <= 0:
            fps = self.DEFAULT_FPS
        return max(1, int(fps * self.SEGMENT_DURATION))

    def _process_segment(self, segment_frames):
        """Transcribe one segment; return "" when VideoProcess yields nothing."""
        landmarks = self.landmarks_detector(segment_frames)
        processed = self.video_process(segment_frames, landmarks)
        if processed is None:
            return ""
        video_tensor = torch.tensor(processed)
        # Assumes channels-last frames (T, H, W, C); reorder to (T, C, H, W)
        # before the transform — TODO confirm against VideoProcess output.
        video_tensor = video_tensor.permute((0, 3, 1, 2))
        video_tensor = self.video_transform(video_tensor)
        with torch.no_grad():
            transcript = self.modelmodule(video_tensor)
        return transcript.strip()

    def forward(self, data_filename):
        """Transcribe the video stored at *data_filename*.

        Segments shorter than ``MIN_FRAMES`` (typically the tail) are skipped,
        as are segments whose transcript is empty.

        Raises:
            FileNotFoundError: if *data_filename* is not an existing file.
        """
        data_filename = os.path.abspath(data_filename)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not os.path.isfile(data_filename):
            raise FileNotFoundError(f"File not found: {data_filename}")

        video, fps = self.load_video(data_filename)
        segment_size = self._segment_size(fps)

        transcripts = []
        for start in range(0, len(video), segment_size):
            segment = video[start:start + segment_size]
            if len(segment) < self.MIN_FRAMES:
                continue
            result = self._process_segment(segment)
            if result:
                transcripts.append(result)
        return " ".join(transcripts)
# Build the inference pipeline a single time at import so every Gradio
# request reuses the same detector and loaded VSR weights.
pipeline = InferencePipeline(args=args, ckpt_path=model_path)
def transcribe(video_path):
    """Gradio callback: transcribe the uploaded video at *video_path*.

    Always returns a string for the output textbox: the transcript, a hint
    when nothing was recognised, a prompt when no file was supplied, or an
    ``"Error: ..."`` message when the pipeline raised.
    """
    # Gradio passes None when nothing is uploaded; treat "" the same way rather
    # than letting it resolve to the current working directory.
    if not video_path:
        return "Please upload a video file."
    try:
        result = pipeline(video_path)
    except Exception as e:
        # UI boundary: show a short message to the user, but keep the full
        # traceback in the server logs instead of discarding it.
        import logging

        logging.getLogger(__name__).exception(
            "Transcription failed for %s", video_path
        )
        return f"Error: {e}"
    if not result:
        return (
            "No speech detected. Make sure the video clearly shows "
            "a speaker's face (front-facing, good lighting)."
        )
    return result
# Gradio UI: one uploaded video in, one copyable transcript out.
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Video(label="Upload Video (mp4 / avi / mov, max 100 MB)"),
    outputs=gr.Textbox(
        label="Transcription",
        lines=6,
        show_copy_button=True,
    ),
    title="VisNet — Visual Speech Recognition",
    description=(
        "Upload a video to transcribe speech from lip movements — **no audio required**.\n\n"
        "**Tips for best results:** front-facing camera, clear face visibility, good lighting.\n\n"
        "⚠️ Running on CPU — inference may take several minutes for longer videos."
    ),
    allow_flagging="never",
)
# Route requests through Gradio's queue, then start the server.
demo.queue()
demo.launch()