import argparse
import os
import random
import time
from datetime import timedelta
from threading import Thread

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import webvtt
import whisper
import yaml
from moviepy.editor import VideoFileClip
from torchvision import transforms
from transformers import TextIteratorStreamer

from minigpt4.common.eval_utils import init_model
from minigpt4.conversation.conversation import CONV_VISION

def format_timestamp(seconds):
    """Convert a duration in seconds to a WebVTT timestamp string (HH:MM:SS.mmm)."""
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    milliseconds = int(td.microseconds / 1000)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"
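
# Quick sanity check for the conversion above (the value is illustrative,
# not from the original script): format_timestamp(3661.25) -> "01:01:01.250".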

def extract_video_info(video_path, max_images_length):
    """Return (sampling_interval, fps) so at most max_images_length frames are sampled."""
    clip = VideoFileClip(video_path)
    fps = clip.fps  # read attributes before closing the clip
    total_num_frames = int(clip.duration * fps)
    clip.close()
    sampling_interval = int(total_num_frames / max_images_length)
    if sampling_interval == 0:
        sampling_interval = 1
    return sampling_interval, fps
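
# Illustrative numbers: a 60 s clip at 30 fps with max_images_length=45 gives
# int(1800 / 45) = 40, i.e. one frame is kept every 40 decoded frames.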

def time_to_milliseconds(time_str):
    """Convert an 'HH:MM:SS.mmm' timestamp string to integer milliseconds."""
    h, m, s = map(float, time_str.split(':'))
    return int((h * 3600 + m * 60 + s) * 1000)
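
# Illustrative example: time_to_milliseconds("00:01:30.500") -> 90500.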

def extract_subtitles(subtitle_path):
    subtitles = []
    for caption in webvtt.read(subtitle_path):
        start_ms = time_to_milliseconds(caption.start)
        end_ms = time_to_milliseconds(caption.end)
        text = caption.text.strip().replace('\n', ' ')
        subtitles.append((start_ms, end_ms, text))
    return subtitles

def find_subtitle(subtitles, frame_count, fps):
    """Binary-search for the subtitle whose time span covers the given frame.

    Assumes the (start_ms, end_ms, text) tuples from extract_subtitles are
    sorted by start time and non-overlapping.
    """
    frame_time = (frame_count / fps) * 1000  # frame position in milliseconds
    left, right = 0, len(subtitles) - 1
    while left <= right:
        mid = (left + right) // 2
        start, end, subtitle_text = subtitles[mid]
        if start <= frame_time <= end:
            return subtitle_text
        elif frame_time < start:
            right = mid - 1
        else:
            left = mid + 1
    return None
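
# Illustrative example (the subtitle tuples are hypothetical): with
# subtitles = [(0, 1000, "hi"), (1500, 3000, "there")], a frame at
# frame_count=60 and fps=30 maps to t = 2000 ms, so the search returns "there".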

def match_frames_and_subtitles(video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames):
    """Sample frames from the video and interleave them with their subtitles.

    Returns a stacked frame tensor and a prompt placeholder string, or
    (None, None) if no frames could be read. Relies on the module-level
    `vis_processor` created by init_model below.
    """
    cap = cv2.VideoCapture(video_path)
    images = []
    frame_count = 0
    img_placeholder = ""
    subtitle_text_in_interval = ""
    history_subtitles = {}
    number_of_words = 0
    transform = transforms.Compose([
        transforms.ToPILImage(),
    ])

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Collect each subtitle once as its time span is crossed.
        if len(subtitles) > 0:
            frame_subtitle = find_subtitle(subtitles, frame_count, fps)
            if frame_subtitle and not history_subtitles.get(frame_subtitle, False):
                subtitle_text_in_interval += frame_subtitle + " "
                history_subtitles[frame_subtitle] = True
        if frame_count % sampling_interval == 0:
            # OpenCV decodes BGR; reverse the channel axis to get RGB for PIL.
            frame = transform(frame[:, :, ::-1])
            frame = vis_processor(frame)
            images.append(frame)
            img_placeholder += '<Img><ImageHere>'
            if subtitle_text_in_interval != "" and number_of_words < max_sub_len:
                img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                number_of_words += len(subtitle_text_in_interval.split(' '))
                subtitle_text_in_interval = ""
        frame_count += 1
        if len(images) >= max_frames:
            break
    cap.release()
    cv2.destroyAllWindows()

    if len(images) == 0:
        return None, None
    images = torch.stack(images)
    return images, img_placeholder
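
# The resulting placeholder interleaves one '<Img><ImageHere>' token per sampled
# frame with optional '<Cap>...' subtitle text, e.g. (text illustrative):
# '<Img><ImageHere><Cap>hello there <Img><ImageHere><Img><ImageHere>...'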

def prepare_input(video_path, subtitle_path, instruction):
    """Build the frame tensor and the full prompt for one video."""
    # Mistral checkpoints use a larger frame and subtitle budget than the Llama 2 ones.
    if "mistral" in args.ckpt:
        max_frames = 90
        max_sub_len = 800
    else:
        max_frames = 45
        max_sub_len = 400
    sampling_interval, fps = extract_video_info(video_path, max_frames)
    # Guard against a missing subtitle file (gen_subtitles=False or Whisper failure).
    subtitles = extract_subtitles(subtitle_path) if subtitle_path else []
    frames_features, input_placeholder = match_frames_and_subtitles(
        video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames)
    if frames_features is None:
        return None, None
    input_placeholder += "\n" + instruction
    return frames_features, input_placeholder

def extract_audio(video_path, audio_path):
    """Extract the audio track of a video to an mp3 file."""
    video_clip = VideoFileClip(video_path)
    audio_clip = video_clip.audio
    audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")

def get_subtitles(video_path):
    """Return a .vtt path for the video, transcribing it with Whisper if needed."""
    audio_dir = "workspace/inference_subtitles/mp3"
    subtitle_dir = "workspace/inference_subtitles"
    os.makedirs(subtitle_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)
    video_id = video_path.split('/')[-1].split('.')[0]
    subtitle_path = f"{subtitle_dir}/{video_id}.vtt"
    # Reuse a transcript generated on a previous run.
    if os.path.exists(subtitle_path):
        return subtitle_path
    audio_path = f"{audio_dir}/{video_id}.mp3"
    try:
        extract_audio(video_path, audio_path)
        result = whisper_model.transcribe(audio_path, language="en")
        # Write the Whisper segments out as WebVTT.
        with open(subtitle_path, "w", encoding="utf-8") as vtt_file:
            vtt_file.write("WEBVTT\n\n")
            for segment in result['segments']:
                start = format_timestamp(segment['start'])
                end = format_timestamp(segment['end'])
                text = segment['text']
                vtt_file.write(f"{start} --> {end}\n{text}\n\n")
        return subtitle_path
    except Exception as e:
        print(f"Error during subtitle generation for {video_path}: {e}")
        return None
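
# The written file is standard WebVTT; a generated transcript looks like
# (contents illustrative):
#
#   WEBVTT
#
#   00:00:00.000 --> 00:00:02.500
#   Hello and welcome.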

def stream_answer(generation_kwargs):
    streamer = TextIteratorStreamer(model.llama_tokenizer, skip_special_tokens=True)
    generation_kwargs['streamer'] = streamer
    thread = Thread(target=model_generate, kwargs=generation_kwargs)
    thread.start()
    return streamer

def escape_markdown(text):
    md_chars = ['<', '>']
    for char in md_chars:
        text = text.replace(char, '\\' + char)
    return text

def model_generate(*args, **kwargs):
    # Run generation under the model's autocast context (mixed precision when available).
    with model.maybe_autocast():
        output = model.llama_model.generate(*args, **kwargs)
    return output

def generate_prediction(video_path, instruction, gen_subtitles=True, stream=True):
    if gen_subtitles:
        subtitle_path = get_subtitles(video_path)
    else:
        subtitle_path = None
    prepared_images, prepared_instruction = prepare_input(video_path, subtitle_path, instruction)
    if prepared_images is None:
        return "Video can't be opened; check the video path again"
    length = len(prepared_images)
    prepared_images = prepared_images.unsqueeze(0)  # add a batch dimension
    conv = CONV_VISION.copy()
    conv.system = ""
    conv.append_message(conv.roles[0], prepared_instruction)
    conv.append_message(conv.roles[1], None)
    prompt = [conv.get_prompt()]

    if stream:
        generation_kwargs = model.answer_prepare_for_streaming(
            prepared_images, prompt, max_new_tokens=args.max_new_tokens,
            do_sample=True, lengths=[length], num_beams=1)
        streamer = stream_answer(generation_kwargs)
        print("Streamed answer:", end='')
        for token in streamer:
            print(token, end='')
        print()
    else:
        setup_seeds(50)
        answers = model.generate(prepared_images, prompt, max_new_tokens=args.max_new_tokens,
                                 do_sample=True, lengths=[length], num_beams=1)
        print("Generated answer:", answers[0])

def get_arguments():
    parser = argparse.ArgumentParser(description="Inference parameters")
    parser.add_argument("--cfg-path", default="test_configs/llama2_test_config.yaml", help="path to configuration file")
    parser.add_argument("--ckpt", type=str, default='checkpoints/video_llama_checkpoint_last.pth', help="path to checkpoint")
    parser.add_argument("--add_subtitles", action='store_true', help="whether to add subtitles")
    parser.add_argument("--stream", action='store_true', help="whether to stream the answer")
    parser.add_argument("--question", type=str, help="question to ask")
    parser.add_argument("--video_path", type=str, help="path to the video file")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
    parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
    parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config: key-value pairs "
             "in xxx=yyy format will be merged into the config file (deprecated; "
             "use --cfg-options instead)",
    )
    return parser.parse_args()

args = get_arguments()

def setup_seeds(seed):
    """Seed all RNGs and make cuDNN deterministic for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True

with open('test_configs/llama2_test_config.yaml') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
seed = config['run']['seed']
print("seed", seed)

# Module-level state used by the functions above.
model, vis_processor, whisper_gpu_id, minigpt4_gpu_id, answer_module_gpu_id = init_model(args)
whisper_model = whisper.load_model("large").to(f"cuda:{whisper_gpu_id}")
conv = CONV_VISION.copy()
conv.system = ""

if __name__ == "__main__":
    video_path = args.video_path
    instruction = args.question
    add_subtitles = args.add_subtitles
    stream = args.stream
    setup_seeds(50)
    t1 = time.time()
    generate_prediction(video_path, instruction, gen_subtitles=add_subtitles, stream=stream)
    print("Time taken for inference:", time.time() - t1)
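
# Example invocation (paths are illustrative, and the script name assumes this
# file is saved as inference.py):
#
#   python inference.py --video_path videos/demo.mp4 \
#       --question "What is happening in this video?" \
#       --add_subtitles --stream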