Spaces:
Runtime error
Runtime error
Delete pipeline.py
Browse files- pipeline.py +0 -73
pipeline.py
DELETED
|
@@ -1,73 +0,0 @@
|
|
| 1 |
-
#! /usr/bin/env python
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
|
| 4 |
-
# Copyright 2023 Imperial College London (Pingchuan Ma)
|
| 5 |
-
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 6 |
-
|
| 7 |
-
import os
|
| 8 |
-
import torch
|
| 9 |
-
import pickle
|
| 10 |
-
from configparser import ConfigParser
|
| 11 |
-
|
| 12 |
-
from pipelines.model import AVSR
|
| 13 |
-
from pipelines.data.data_module import AVSRDataLoader
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class InferencePipeline(torch.nn.Module):
|
| 17 |
-
def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
|
| 18 |
-
super(InferencePipeline, self).__init__()
|
| 19 |
-
assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
|
| 20 |
-
|
| 21 |
-
config = ConfigParser()
|
| 22 |
-
config.read(config_filename)
|
| 23 |
-
|
| 24 |
-
# modality configuration
|
| 25 |
-
modality = config.get("input", "modality")
|
| 26 |
-
|
| 27 |
-
self.modality = modality
|
| 28 |
-
# data configuration
|
| 29 |
-
input_v_fps = config.getfloat("input", "v_fps")
|
| 30 |
-
model_v_fps = config.getfloat("model", "v_fps")
|
| 31 |
-
|
| 32 |
-
# model configuration
|
| 33 |
-
model_path = config.get("model","model_path")
|
| 34 |
-
model_conf = config.get("model","model_conf")
|
| 35 |
-
|
| 36 |
-
# language model configuration
|
| 37 |
-
rnnlm = config.get("model", "rnnlm")
|
| 38 |
-
rnnlm_conf = config.get("model", "rnnlm_conf")
|
| 39 |
-
penalty = config.getfloat("decode", "penalty")
|
| 40 |
-
ctc_weight = config.getfloat("decode", "ctc_weight")
|
| 41 |
-
lm_weight = config.getfloat("decode", "lm_weight")
|
| 42 |
-
beam_size = config.getint("decode", "beam_size")
|
| 43 |
-
|
| 44 |
-
self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps/model_v_fps, detector=detector)
|
| 45 |
-
self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device)
|
| 46 |
-
if face_track and self.modality in ["video", "audiovisual"]:
|
| 47 |
-
if detector == "mediapipe":
|
| 48 |
-
from pipelines.detectors.mediapipe.detector import LandmarksDetector
|
| 49 |
-
self.landmarks_detector = LandmarksDetector()
|
| 50 |
-
if detector == "retinaface":
|
| 51 |
-
from pipelines.detectors.retinaface.detector import LandmarksDetector
|
| 52 |
-
self.landmarks_detector = LandmarksDetector(device="cuda:0")
|
| 53 |
-
else:
|
| 54 |
-
self.landmarks_detector = None
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def process_landmarks(self, data_filename, landmarks_filename):
|
| 58 |
-
if self.modality == "audio":
|
| 59 |
-
return None
|
| 60 |
-
if self.modality in ["video", "audiovisual"]:
|
| 61 |
-
if isinstance(landmarks_filename, str):
|
| 62 |
-
landmarks = pickle.load(open(landmarks_filename, "rb"))
|
| 63 |
-
else:
|
| 64 |
-
landmarks = self.landmarks_detector(data_filename)
|
| 65 |
-
return landmarks
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def forward(self, data_filename, landmarks_filename=None):
|
| 69 |
-
assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
|
| 70 |
-
landmarks = self.process_landmarks(data_filename, landmarks_filename)
|
| 71 |
-
data = self.dataloader.load_data(data_filename, landmarks)
|
| 72 |
-
transcript = self.model.infer(data)
|
| 73 |
-
return transcript
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|