{ "cells": [ { "cell_type": "markdown", "id": "556076dd-8537-4356-8ac3-f4a838b2e896", "metadata": {}, "source": [ "# AUDIO/VISUAL SPEECH RECOGNITION" ] }, { "cell_type": "raw", "id": "27269bb4-af6f-4d93-9141-799400b9131e", "metadata": {}, "source": [ "This tutorial shows how to use our `InferencePipeline` to perform audio or visual speech recognition." ] }, { "cell_type": "markdown", "id": "5e067f7d-cc60-43a0-b65a-7272c9cecc81", "metadata": {}, "source": [ "**Note** This tutorial requires `mediapipe` or `retinaface` detector. Please refer to [preparation](../preparation#setup) for installation." ] }, { "cell_type": "markdown", "id": "c97575fc-bb77-43d0-8eb7-f94baf4cdc89", "metadata": {}, "source": [ "**Note** To run this tutorial, please make sure you are in tutorials folder." ] }, { "cell_type": "code", "execution_count": null, "id": "fa81fb4f-6e89-4238-9f47-3f66eedd3c22", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, \"../\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d2fb0de3-af7f-477e-baef-dd595effeef8", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torchaudio\n", "import torchvision" ] }, { "cell_type": "markdown", "id": "2875bb66-4e96-4c51-85d6-a129748d197a", "metadata": {}, "source": [ "## 1. Build an inference pipeline\n", "\n", "The InferencePipeline carries out the following three steps:\n", "\n", "1. Load audio or video data\n", "2. Run pre-processing functions\n", "3. Run inference" ] }, { "cell_type": "code", "execution_count": null, "id": "1ebfe406-2cd8-4b1a-bf62-32ce1934fc23", "metadata": {}, "outputs": [], "source": [ "import os\n", "from lightning import ModelModule\n", "from datamodule.transforms import AudioTransform, VideoTransform" ] }, { "cell_type": "code", "execution_count": null, "id": "6da3a91c-f515-45f5-8172-ababfd786596", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "parser = argparse.ArgumentParser()\n", "args, _ = parser.parse_known_args(args=[])" ] }, { "cell_type": "code", "execution_count": null, "id": "ac03a1e1-3082-43ca-bb44-e58f96464bc3", "metadata": {}, "outputs": [], "source": [ "class InferencePipeline(torch.nn.Module):\n", " def __init__(self, args, ckpt_path, detector=\"retinaface\"):\n", " super(InferencePipeline, self).__init__()\n", " self.modality = args.modality\n", " if self.modality == \"audio\":\n", " self.audio_transform = AudioTransform(subset=\"test\")\n", " elif self.modality == \"video\":\n", " if detector == \"mediapipe\":\n", " from preparation.detectors.mediapipe.detector import LandmarksDetector\n", " from preparation.detectors.mediapipe.video_process import VideoProcess\n", " self.landmarks_detector = LandmarksDetector()\n", " self.video_process = VideoProcess(convert_gray=False)\n", " elif detector == \"retinaface\":\n", " from preparation.detectors.retinaface.detector import LandmarksDetector\n", " from preparation.detectors.retinaface.video_process import VideoProcess\n", " self.landmarks_detector = LandmarksDetector(device=\"cuda:0\")\n", " self.video_process = VideoProcess(convert_gray=False)\n", " self.video_transform = VideoTransform(subset=\"test\")\n", "\n", " ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)\n", " self.modelmodule = ModelModule(args)\n", " self.modelmodule.model.load_state_dict(ckpt)\n", " self.modelmodule.eval()\n", "\n", " def load_video(self, data_filename):\n", " return torchvision.io.read_video(data_filename, pts_unit=\"sec\")[0].numpy()\n", "\n", " def forward(self, data_filename):\n", " data_filename = os.path.abspath(data_filename)\n", " assert os.path.isfile(data_filename), f\"data_filename: {data_filename} does not exist.\"\n", "\n", " if self.modality == \"audio\":\n", " audio, sample_rate = self.load_audio(data_filename)\n", " audio = self.audio_process(audio, sample_rate)\n", " audio = audio.transpose(1, 0)\n", " audio = self.audio_transform(audio)\n", " with torch.no_grad():\n", " transcript = self.modelmodule(audio)\n", "\n", " if self.modality == \"video\":\n", " video = self.load_video(data_filename)\n", " landmarks = self.landmarks_detector(video)\n", " video = self.video_process(video, landmarks)\n", " video = torch.tensor(video)\n", " video = video.permute((0, 3, 1, 2))\n", " video = self.video_transform(video)\n", " with torch.no_grad():\n", " transcript = self.modelmodule(video)\n", "\n", " return transcript\n", "\n", " def load_audio(self, data_filename):\n", " waveform, sample_rate = torchaudio.load(data_filename, normalize=True)\n", " return waveform, sample_rate\n", "\n", " def load_video(self, data_filename):\n", " return torchvision.io.read_video(data_filename, pts_unit=\"sec\")[0].numpy()\n", "\n", " def audio_process(self, waveform, sample_rate, target_sample_rate=16000):\n", " if sample_rate != target_sample_rate:\n", " waveform = torchaudio.functional.resample(\n", " waveform, sample_rate, target_sample_rate\n", " )\n", " waveform = torch.mean(waveform, dim=0, keepdim=True)\n", " return waveform" ] }, { "cell_type": "markdown", "id": "7dfa0be0-788d-42b4-8873-0cd7ce804eb3", "metadata": {}, "source": [ "## 2. Download a video from web" ] }, { "cell_type": "code", "execution_count": null, "id": "ea438ae5-2c1b-420b-a7f2-0a23e3c22ed3", "metadata": {}, "outputs": [], "source": [ "!wget --content-disposition http://www.doc.ic.ac.uk/~pm4115/autoAVSR/autoavsr_demo_video.mp4 -O ./input.mp4\n", "data_filename = \"input.mp4\"" ] }, { "cell_type": "markdown", "id": "6cde450d-eeea-4ac3-a651-63a1ddd24258", "metadata": {}, "source": [ "## 3. VSR inference" ] }, { "cell_type": "markdown", "id": "58ee0e65-8ef0-4fff-93d3-615b3ad26c77", "metadata": {}, "source": [ "### 3.1 Download a pre-trained model" ] }, { "cell_type": "code", "execution_count": null, "id": "9cc4442d-a112-4822-895c-4faddd8bf808", "metadata": {}, "outputs": [], "source": [ "!wget http://www.doc.ic.ac.uk/~pm4115/autoAVSR/vsr_trlrs3_base.pth -O ./vsr_trlrs3_base.pth\n", "model_path = \"./vsr_trlrs3_base.pth\"" ] }, { "cell_type": "markdown", "id": "aaf16c2c-5293-4824-9692-167ec18909bd", "metadata": {}, "source": [ "### 3.2 Initialize VSR pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "401c84d6-5910-4959-8761-34cbd1d18c56", "metadata": {}, "outputs": [], "source": [ "setattr(args, 'modality', 'video')\n", "pipeline = InferencePipeline(args, model_path, detector=\"retinaface\")" ] }, { "cell_type": "markdown", "id": "bde540f5-6bdb-4030-a770-34da73928f39", "metadata": {}, "source": [ "### 3.3 Run inference" ] }, { "cell_type": "code", "execution_count": null, "id": "0c0fe308-6a0f-4c57-869f-a0114bebdf59", "metadata": {}, "outputs": [], "source": [ "transcript = pipeline(\"input.mp4\")\n", "print(transcript)" ] }, { "cell_type": "markdown", "id": "7b3e7193-93fb-4c1b-a5a4-af97eb37d0a1", "metadata": {}, "source": [ "## 4. ASR inference" ] }, { "cell_type": "markdown", "id": "06ca5080-e993-43c7-bf93-ec4e367c99ba", "metadata": {}, "source": [ "### 4.1 Download a pre-trained model" ] }, { "cell_type": "code", "execution_count": null, "id": "30a9a7d3-545f-4cb5-8392-341dde7f4bf9", "metadata": {}, "outputs": [], "source": [ "!wget http://www.doc.ic.ac.uk/~pm4115/autoAVSR/asr_trlrs3_base.pth -O ./asr_trlrs3_base.pth\n", "model_path = \"./asr_trlrs3_base.pth\"" ] }, { "cell_type": "markdown", "id": "11071d0b-a20e-4486-a3e0-17996ce9dda4", "metadata": {}, "source": [ "### 4.2 Initialize ASR pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "c923e7ad-309c-44cd-aa14-d37e83b42714", "metadata": {}, "outputs": [], "source": [ "setattr(args, 'modality', 'audio')\n", "pipeline = InferencePipeline(args, model_path)" ] }, { "cell_type": "markdown", "id": "32ef0be9-aef4-4013-af61-12ebbe9c8671", "metadata": {}, "source": [ "### 4.3 Run inference" ] }, { "cell_type": "code", "execution_count": null, "id": "332254c8-6c18-40a4-8056-0e896b83adf6", "metadata": {}, "outputs": [], "source": [ "transcript = pipeline(\"input.mp4\")\n", "print(transcript)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 5 }