from transformers import pipeline import torch from nltk.tokenize import sent_tokenize import nltk import pandas as pd import numpy as np import os import sys import pathlib folder_path = pathlib.Path(__file__).parent.resolve() sys.path.append(os.path.join(folder_path,'../')) from utils import load_subtitles_dataset nltk.download('punkt') nltk.download('punkt_tab') class ThemeClassifier(): def __init__(self, theme_list): self.model_name = "facebook/bart-large-mnli" self.device = 0 if torch.cuda.is_available() else 'cpu' self.theme_list = theme_list self.theme_classifier = self.load_model(self.device) def load_model(self,device): theme_classifier = pipeline( "zero-shot-classification", model=self.model_name, device=device ) return theme_classifier def get_themes_inference(self, script): script_sentences = sent_tokenize(script) # Batch Sentence sentence_batch_size=20 script_batches = [] for index in range(0,len(script_sentences),sentence_batch_size): sent = " ".join(script_sentences[index:index+sentence_batch_size]) script_batches.append(sent) # Run Model theme_output = self.theme_classifier( script_batches, self.theme_list, multi_label=True ) # Wrangle Output themes = {} for output in theme_output: for label,score in zip(output['labels'],output['scores']): if label not in themes: themes[label] = [] themes[label].append(score) themes = {key: np.mean(np.array(value)) for key,value in themes.items()} return themes def get_themes(self,dtaset_path, save_path=None): # Read Save Output if Exists if save_path is not None and os.path.exists(save_path): df = pd.read_csv(save_path) return df # load Dataset df = load_subtitles_dataset(dtaset_path) # Run Inference output_themes = df['script'].apply(self.get_themes_inference) themes_df = pd.DataFrame(output_themes.tolist()) df[themes_df.columns] = themes_df # Save output if save_path is not None: df.to_csv(save_path,index=False) return df