| from simpletransformers.classification import ClassificationModel, ClassificationArgs | |
| from typing import Dict, List, Any | |
| import pandas as pd | |
| import webvtt | |
| from datetime import datetime | |
| import torch | |
| import spacy | |
class EndpointHandler:
    """Endpoint handler wrapping a simpletransformers RoBERTa classifier.

    The classification model is loaded once, when the handler is
    constructed. It runs on the GPU whenever torch reports that CUDA is
    available.
    """

    def __init__(self, path="."):
        """Load the RoBERTa classification model.

        Args:
            path: Passed to ``ClassificationModel`` as its model name/path
                argument. Defaults to the current directory; presumably a
                directory holding the saved model (confirm with deployment).
        """
        print("Loading models...")
        self.model = ClassificationModel(
            "roberta",
            path,
            use_cuda=torch.cuda.is_available(),
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Handle a request for a WebVTT subtitle file.

        Args:
            data_file: Filename of a ``.vtt`` file.

        Returns:
            Always an empty list for now: parsing ``data_file`` and running
            ``self.model`` on it has not been implemented yet.
        """
        # TODO: read the .vtt captions and classify them with self.model.
        predictions: List[Dict[str, Any]] = []
        return predictions