Spaces:
Runtime error
Runtime error
| import os | |
| import csv | |
| import json | |
| import torch | |
| import argparse | |
| import pandas as pd | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from collections import defaultdict | |
| from transformers.models.llama.tokenization_llama import LlamaTokenizer | |
| from torch.utils.data import DataLoader | |
| from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration | |
| from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor | |
| from peft import LoraConfig, get_peft_model | |
| from data_utils.xgpt3_dataset import MultiModalDataset | |
| from utils import batchify | |
| import gradio as gr | |
| from entailment_inference import get_scores | |
| from nle_inference import VideoCaptionDataset, get_nle | |
# Base mPLUG-Owl video checkpoint directory and the fine-tuned Owl-Con weights.
pretrained_ckpt = "mplugowl7bvideo/"
trained_ckpt = "owl-con/checkpoint-5178/pytorch_model.bin"

# Tokenizer and image/video processor come from the base checkpoint.
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)

# Instantiate model on CPU first; it is moved to the target device only after
# the fine-tuned weights have been loaded into it.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'},
)

# Wrap the language model's attention/MLP projections with LoRA adapters so the
# module structure matches the state dict stored in `trained_ckpt`.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# NOTE(review): torch.load unpickles the file; this is only safe because the
# checkpoint is a local, trusted artifact.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device("cpu"))
model.load_state_dict(ckpt)

# Prefer the first GPU, but fall back to CPU instead of crashing at import time
# on a CPU-only host. NOTE(review): helpers in entailment_inference /
# nle_inference may still assume CUDA — confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device).to(torch.bfloat16)

# get_peft_model() adds new LoRA dropout modules after from_pretrained() has
# already set eval mode; freshly created nn.Modules start in training mode, so
# switch the whole model to eval to disable dropout during inference.
model.eval()
def inference(videopath, text, nle_threshold=0.5):
    """Score whether a video entails a caption, explaining low-scoring pairs.

    Args:
        videopath: Path of the uploaded video as supplied by the ``gr.Video``
            input; ``None`` when the user did not upload anything.
        text: Caption to test against the video.
        nle_threshold: Entailment scores strictly below this value trigger
            generation of a natural language explanation (NLE).

    Returns:
        A ``(score, nle)`` tuple: the value returned by ``get_scores`` and
        either the generated explanation or a placeholder message.

    Raises:
        gr.Error: If no video was provided, so Gradio shows a readable message
            instead of a traceback from deep inside the dataset code.
    """
    if not videopath:
        raise gr.Error("Please upload a video before running the demo.")

    PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: Does this video entail the description: "{caption}"?
AI: """

    # Pinned host memory only helps (and is only supported) with a CUDA device.
    use_pinned_memory = torch.cuda.is_available()

    valid_data = MultiModalDataset(
        videopath,
        PROMPT.format(caption=text),
        tokenizer,
        processor,
        max_length=256,
        loss_objective='sequential',
    )
    dataloader = DataLoader(valid_data, pin_memory=use_pinned_memory, collate_fn=batchify)
    score = get_scores(model, tokenizer, dataloader)

    # Only generate an explanation when the pair looks non-entailed.
    if score < nle_threshold:
        dataset = VideoCaptionDataset(videopath, text)
        dataloader = DataLoader(dataset)
        nle = get_nle(model, processor, tokenizer, dataloader)
    else:
        nle = f"None (NLE is only triggered when entailment score < {nle_threshold})"
    return score, nle
# Gradio UI: a video plus a caption in, the entailment score and an optional
# natural language explanation out.
demo = gr.Interface(
    inference,
    title="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
    inputs=[gr.Video(label='input_video'), gr.Textbox(label='input_caption')],
    outputs=[gr.Number(label='Entailment Score'), gr.Textbox(label='Natural Language Explanation')],
)

if __name__ == "__main__":
    demo.launch()