Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoTokenizer, VisionEncoderDecoderModel | |
| import utils | |
class Inference:
    """Generate a text summary for a video with a vision encoder-decoder model.

    The tokenizer is loaded from ``decoder_model_name``. The encoder-decoder
    weights are loaded from ``encoder_decoder_model_name``, which defaults to
    ``'armgabrielyan/video-summarization'``.
    """

    def __init__(
        self,
        decoder_model_name,
        max_length=50,
        encoder_decoder_model_name='armgabrielyan/video-summarization',
    ):
        """Load the tokenizer and model, then move the model to CUDA if available, else CPU.

        Args:
            decoder_model_name: Hub id or local path passed to
                ``AutoTokenizer.from_pretrained``.
            max_length: Default ``max_length`` used by ``generate_text``.
            encoder_decoder_model_name: Hub id or local path passed to
                ``VisionEncoderDecoderModel.from_pretrained``. Defaults to the
                checkpoint this class previously hard-coded.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(
            encoder_decoder_model_name
        )
        # Do the one-time pad-token setup eagerly, before the device move, so
        # the first generate_text call does not have to mutate the model.
        self._ensure_pad_token()
        self.encoder_decoder_model.to(self.device)
        self.max_length = max_length

    def _ensure_pad_token(self):
        """Add a '[PAD]' token if the tokenizer has none and resize the decoder embeddings to match."""
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            # Resize the wrapped decoder's embeddings so every tokenizer id,
            # including the new pad id, has a row.
            self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))

    def generate_text(self, video, encoder_model_name, **generate_kwargs):
        """Return generated text for ``video``.

        Args:
            video: Either a path (``str``), converted to pixel values with
                ``utils.video2image_from_path``, or already-prepared pixel
                values. Those must support ``.unsqueeze`` and ``.to``
                (presumably a ``torch.Tensor`` — TODO confirm).
            encoder_model_name: Forwarded to ``utils.video2image_from_path``;
                unused when ``video`` is not a ``str``.
            **generate_kwargs: Optional overrides for the ``generate`` call.
                The defaults are ``early_stopping=True``,
                ``max_length=self.max_length``, ``num_beams=4`` and
                ``no_repeat_ngram_size=2``.

        Returns:
            The first decoded sequence, with special tokens skipped.
        """
        if isinstance(video, str):
            pixel_values = utils.video2image_from_path(video, encoder_model_name)
        else:
            pixel_values = video

        # No-op when __init__ already ran. Kept so that behaviour is unchanged
        # if the tokenizer is replaced after construction.
        self._ensure_pad_token()

        generation_options = {
            'early_stopping': True,
            'max_length': self.max_length,
            'num_beams': 4,
            'no_repeat_ngram_size': 2,
        }
        generation_options.update(generate_kwargs)

        # unsqueeze(0) adds a leading batch dimension. This presumably means
        # pixel_values is a single un-batched clip — TODO confirm against
        # utils.video2image_from_path.
        generated_ids = self.encoder_decoder_model.generate(
            pixel_values.unsqueeze(0).to(self.device),
            **generation_options,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]