"""Zero-shot theme classification of subtitle scripts.

Each script is split into sentence batches, every batch is scored against a
user-supplied list of theme labels with a zero-shot classification pipeline,
and the per-batch scores are averaged into one score per theme.
"""
import os
import pathlib
import sys
from collections import defaultdict

import nltk
import numpy as np
import pandas as pd
import torch
from nltk.tokenize import sent_tokenize
from transformers import pipeline

folder_path = pathlib.Path(__file__).parent.resolve()
# Put the parent directory on sys.path so the project-level `utils` module resolves.
sys.path.append(os.path.join(folder_path, '../'))
from utils import load_subtitles_dataset  # noqa: E402  (depends on the sys.path tweak above)

# Tokenizer data needed by sent_tokenize ('punkt_tab' is looked up by newer NLTK releases).
nltk.download('punkt')
nltk.download('punkt_tab')


class ThemeClassifier:
    """Score scripts against a fixed list of themes with a zero-shot model.

    Args:
        theme_list: Candidate theme labels passed to the zero-shot pipeline.
        sentence_batch_size: Number of consecutive sentences joined into one
            model input. Defaults to 20.
    """

    def __init__(self, theme_list, sentence_batch_size=20):
        self.model_name = "MoritzLaurer/deberta-v3-large-zeroshot-v2"
        # Pipeline device: CUDA device index 0 when a GPU is available, else CPU.
        self.device = 0 if torch.cuda.is_available() else 'cpu'
        self.theme_list = theme_list
        self.sentence_batch_size = sentence_batch_size
        # Lazy: only load the 1.5GB zero-shot model if we actually run inference
        # (reading a precomputed stub needs no model — keeps the Space light).
        self.theme_classifier = None

    def _classifier(self):
        """Return the zero-shot pipeline, building it on first use."""
        if self.theme_classifier is None:
            self.theme_classifier = self.load_model(self.device)
        return self.theme_classifier

    def load_model(self, device):
        """Build the Hugging Face zero-shot-classification pipeline on *device*."""
        return pipeline(
            "zero-shot-classification",
            model=self.model_name,
            device=device,
        )

    def get_themes_inference(self, script):
        """Return the mean zero-shot score for each theme over *script*.

        The script is sentence-tokenized and grouped into batches of
        ``self.sentence_batch_size`` sentences. Each batch is scored
        independently (multi-label), and the scores for each label are averaged.

        Args:
            script: Full script text. Non-string values (e.g. a NaN cell from
                pandas) and scripts with no sentences are treated as empty.

        Returns:
            Dict mapping theme label -> mean score. Returns an empty dict for an
            empty script, in which case the model is never loaded.
        """
        if not isinstance(script, str):
            # Missing subtitles (NaN/None) would make sent_tokenize raise.
            return {}

        sentences = sent_tokenize(script)
        batch_size = self.sentence_batch_size
        script_batches = [
            " ".join(sentences[start:start + batch_size])
            for start in range(0, len(sentences), batch_size)
        ]
        if not script_batches:
            # Nothing to classify; avoid loading the large model for no work.
            return {}

        theme_output = self._classifier()(
            script_batches,
            self.theme_list,
            multi_label=True,
        )
        # Some transformers versions return a bare dict, not a one-element
        # list, when the batch holds a single sequence; normalise to a list.
        if isinstance(theme_output, dict):
            theme_output = [theme_output]

        scores_by_theme = defaultdict(list)
        for output in theme_output:
            for label, score in zip(output['labels'], output['scores']):
                scores_by_theme[label].append(score)

        return {
            label: np.mean(np.array(scores))
            for label, scores in scores_by_theme.items()
        }

    def get_themes(self, dataset_path, save_path=None):
        """Load a subtitles dataset and add one score column per theme.

        Args:
            dataset_path: Path handed to ``load_subtitles_dataset``.
            save_path: Optional CSV cache. If it already exists, it is read and
                returned as-is and no inference is run. Otherwise the result is
                written there (parent directories are created if needed).

        Returns:
            The dataset DataFrame with one column per theme label. A row whose
            script is empty or missing gets NaN in the theme columns.
        """
        # Reuse the cached output if it exists.
        if save_path is not None and os.path.exists(save_path):
            return pd.read_csv(save_path)

        df = load_subtitles_dataset(dataset_path)

        output_themes = df['script'].apply(self.get_themes_inference)
        # Reuse df's index so rows stay aligned even if the loader did not
        # return a default RangeIndex.
        themes_df = pd.DataFrame(output_themes.tolist(), index=df.index)
        if len(themes_df.columns) > 0:
            df[themes_df.columns] = themes_df

        if save_path is not None:
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            df.to_csv(save_path, index=False)

        return df