{ "cells": [ { "cell_type": "markdown", "id": "556076dd-8537-4356-8ac3-f4a838b2e896", "metadata": {}, "source": [ "# FEATURE EXTRACTION" ] }, { "cell_type": "markdown", "id": "6b21296a-9b12-40c3-b1fc-c3e316120075", "metadata": {}, "source": [ "This tutorial shows how to use our conformer-based model to extract features from the encoder." ] }, { "cell_type": "markdown", "id": "c97575fc-bb77-43d0-8eb7-f94baf4cdc89", "metadata": {}, "source": [ "**Note:** To run this tutorial, make sure your working directory is the `tutorials` folder, because the next cell adds the relative path `../` to `sys.path`." ] }, { "cell_type": "code", "execution_count": null, "id": "fa81fb4f-6e89-4238-9f47-3f66eedd3c22", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, \"../\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d2fb0de3-af7f-477e-baef-dd595effeef8", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch" ] }, { "cell_type": "code", "execution_count": null, "id": "9993549c-a930-4fd5-a076-411c5e0725b6", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "parser = argparse.ArgumentParser()\n", "args, _ = parser.parse_known_args(args=[])" ] }, { "cell_type": "markdown", "id": "2875bb66-4e96-4c51-85d6-a129748d197a", "metadata": {}, "source": [ "## 1. 
Build our model" ] }, { "cell_type": "code", "execution_count": null, "id": "e7ff5c7d-30f2-498e-b21b-8b2024d314f8", "metadata": {}, "outputs": [], "source": [ "from espnet.nets.pytorch_backend.e2e_asr_conformer import E2E\n", "from pytorch_lightning import LightningModule\n", "from datamodule.transforms import TextTransform" ] }, { "cell_type": "code", "execution_count": null, "id": "ac03a1e1-3082-43ca-bb44-e58f96464bc3", "metadata": {}, "outputs": [], "source": [
 "class ModelModule(LightningModule):\n",
 "    \"\"\"Wrap the E2E network and expose its encoder as a feature extractor.\"\"\"\n",
 "\n",
 "    def __init__(self, args):\n",
 "        \"\"\"Build the text transform and the E2E network from `args`.\n",
 "\n",
 "        `args.modality` is required. `args.ctc_weight` is optional and\n",
 "        falls back to 0.1 when the attribute is absent.\n",
 "        \"\"\"\n",
 "        super().__init__()\n",
 "        self.args = args\n",
 "        self.save_hyperparameters(args)\n",
 "\n",
 "        self.modality = args.modality\n",
 "        self.text_transform = TextTransform()\n",
 "        self.token_list = self.text_transform.token_list\n",
 "\n",
 "        # E2E receives the token-list length, the modality and the CTC weight.\n",
 "        self.model = E2E(\n",
 "            len(self.token_list),\n",
 "            self.modality,\n",
 "            ctc_weight=getattr(args, \"ctc_weight\", 0.1),\n",
 "        )\n",
 "\n",
 "    def forward(self, x):\n",
 "        \"\"\"Return the encoder output for one unbatched input `x`.\n",
 "\n",
 "        A leading dimension is added with `unsqueeze(0)` before the frontend\n",
 "        and `squeeze(0)` is applied to the encoder output before returning.\n",
 "        \"\"\"\n",
 "        feats = self.model.frontend(x.unsqueeze(0))\n",
 "        feats = self.model.proj_encoder(feats)\n",
 "        # The encoder is called with None as its second argument (presumably\n",
 "        # the mask - confirm against E2E); only its first return value is used.\n",
 "        feats, _ = self.model.encoder(feats, None)\n",
 "        return feats.squeeze(0)"
 ] }, { "cell_type": "markdown", "id": "babdaede-b3b0-4a5a-ae5b-66edf25e16a3", "metadata": {}, "source": [ "## 2. Download a pre-trained checkpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "9cc4442d-a112-4822-895c-4faddd8bf808", "metadata": {}, "outputs": [], "source": [ "!wget http://www.doc.ic.ac.uk/~pm4115/autoAVSR/vsr_trlrs3_base.pth -O ./vsr_trlrs3_base.pth\n", "model_path = \"./vsr_trlrs3_base.pth\"" ] }, { "cell_type": "markdown", "id": "a6f10649-5485-47eb-bf91-599e02ba29e7", "metadata": {}, "source": [ "## 3. 
Load weights from the checkpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "401c84d6-5910-4959-8761-34cbd1d18c56", "metadata": {}, "outputs": [], "source": [
 "# Pick the modality before the module is built; ModelModule reads args.modality.\n",
 "args.modality = 'video'\n",
 "model = ModelModule(args)\n",
 "\n",
 "# NOTE(security): torch.load unpickles the file, and this checkpoint was\n",
 "# downloaded over plain HTTP - only load checkpoints from a source you trust.\n",
 "ckpt = torch.load(model_path, map_location=torch.device('cpu'))\n",
 "model.model.load_state_dict(ckpt)\n",
 "\n",
 "# LightningModule.freeze() disables gradients and puts the module in eval mode.\n",
 "model.freeze()"
 ] }, { "cell_type": "markdown", "id": "bde540f5-6bdb-4030-a770-34da73928f39", "metadata": {}, "source": [ "## 4. Use the pre-trained model to extract features" ] }, { "cell_type": "markdown", "id": "c54475a3-6aa1-483e-8bb3-223d31652eee", "metadata": {}, "source": [ "A placeholder x with a shape of (length, num_channel, height, width) is used to represent the input tensor in the lip-reading model." ] }, { "cell_type": "code", "execution_count": null, "id": "0c0fe308-6a0f-4c57-869f-a0114bebdf59", "metadata": {}, "outputs": [], "source": [
 "# Random stand-in input laid out as (length, num_channel, height, width).\n",
 "x = torch.randn(10, 1, 88, 88)\n",
 "with torch.inference_mode():\n",
 "    y = model(x)\n",
 "print(y.shape)"
 ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 5 }