# Multi-GPU evaluation of a LLaVA-OneVision checkpoint on the BASKET
# basketball skill-estimation test set: each worker scores videos on 20
# skills (0-4) and compares the predictions against ground-truth rankings.
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

import torch
import numpy as np
import copy
import re
import warnings

from decord import VideoReader, cpu
from tqdm import tqdm

warnings.filterwarnings("ignore")

import torch.multiprocessing as mp

# CUDA contexts do not survive fork(); spawn a fresh interpreter per worker.
mp.set_start_method('spawn', force=True)

pretrained = "your_llava-ov_checkpoint"
model_name = "llava_qwen"
llava_model_args = {
    "multimodal": True,
}


def load_model(device):
    """Initialize model on a specific GPU."""
    tokenizer, model, image_processor, max_length = load_pretrained_model(
        pretrained, None, model_name, device_map={"": device}, attn_implementation="sdpa", **llava_model_args
    )
    model.eval()
    return tokenizer, model, image_processor


def load_video(video_path, max_frames_num):
    """Uniformly sample max_frames_num frames from a video as an (N, H, W, 3) uint8 array."""
    if isinstance(video_path, str):
        vr = VideoReader(video_path, ctx=cpu(0))
    else:
        vr = VideoReader(video_path[0], ctx=cpu(0))
    total_frame_num = len(vr)
    # Evenly spaced frame indices across the whole clip.
    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
    frame_idx = uniform_sampled_frames.tolist()
    sampled_frames = vr.get_batch(frame_idx).asnumpy()
    return sampled_frames
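
# Illustrative only (not executed; "example_clip.mp4" is a hypothetical path):
#   frames = load_video("example_clip.mp4", 32)
#   frames.shape  ->  (32, H, W, 3), dtype uint8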


def split_list(lst, num_splits):
    """Split a list into num_splits chunks, spreading any remainder over the first chunks."""
    chunk_size = len(lst) // num_splits
    remainder = len(lst) % num_splits
    chunks = []
    start = 0
    for i in range(num_splits):
        # The first `remainder` chunks each take one extra element.
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(lst[start:end])
        start = end
    return chunks
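
# Illustrative only (not executed): 10 samples split across 3 GPUs.
#   split_list(list(range(10)), 3)  ->  [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]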


def process_chunk(gpu_id, chunk, results):
    """Process a chunk of samples on a specific GPU."""
    device = f"cuda:{gpu_id}"
    tokenizer, model, image_processor = load_model(device)
    model = model.to(device)  # no-op safeguard; device_map has already placed the weights

    res_l = []
    for sample in tqdm(chunk, desc=f"GPU {gpu_id}"):
        try:
            # sample[0] is the video path; sample[1:] are the 20 ground-truth rankings.
            video_path = sample[0]
            video_frames = load_video(video_path, 32)
            frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().to(device)

            conv_template = "qwen_1_5"

            question = f"{DEFAULT_IMAGE_TOKEN}\nAnalyze the video content and evaluate the performance for the following 20 skills. Assign a ranking to each skill on a scale of 0 to 4, where 0 represents no impact or poor performance and 4 represents excellent performance. Follow the format: `Skill Name: Ranking` (e.g., `Free Throw: 4`). The skills to evaluate are: Free Throw, 2-PTs, 3-PTs, Contested-shots, Overall Shooting, Rebounds, Defensive Rebounds, Offensive Rebounds, Steals, Turnovers, Points-allowed, Defensive Consistency, Assists, Passing Accuracy, Fouls, Contribution, Offensive Consistency, Teamwork, Impact, Efficiency"

            conv = copy.deepcopy(conv_templates[conv_template])
            conv.append_message(conv.roles[0], question)
            conv.append_message(conv.roles[1], None)
            prompt_question = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
            # video_frames are numpy arrays, so report (width, height) explicitly;
            # ndarray.size would be the total element count, not the frame dimensions.
            image_sizes = [(frame.shape[1], frame.shape[0]) for frame in video_frames]

            cont = model.generate(
                input_ids,
                images=[frames],
                image_sizes=image_sizes,
                do_sample=False,  # greedy decoding for reproducible rankings
                temperature=0,
                max_new_tokens=4096,
                modalities=["video"],
            )
            text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)

            # Pull the 20 integer rankings ("Skill Name: N") out of the reply.
            src_txt = text_outputs[0].replace('\n', ',')

            src_num = re.findall(r':\s*(\d+)', src_txt)
            src_num = [int(num) for num in src_num]
            if len(src_num) != 20:
                print(f'output skill num mismatch! {len(src_num)}')

            tgt_num = [int(x) for x in sample[1:]]

            # Per-sample accuracy: fraction of the 20 skills ranked exactly right.
            matches = sum(1 for x, y in zip(src_num, tgt_num) if x == y) / len(tgt_num)
            res_l.append(matches)

        except Exception as e:
            print(f"Error on GPU {gpu_id}: {e}")
            continue

    results[gpu_id] = res_l
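
# Illustrative only (not executed): a well-formed reply parses as follows.
#   text_outputs[0]  =  "Free Throw: 4\n2-PTs: 3\n...\nEfficiency: 2"
#   re.findall(r':\s*(\d+)', ...)  ->  ['4', '3', ..., '2']   (20 values)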


if __name__ == "__main__":
    test_path = 'your_BASKET_test_path'

    test_list = []
    with open(test_path, 'r') as f:
        # Skip the header row and the last two lines (assumed to be footer/summary rows).
        for line in f.readlines()[1:-2]:
            test_list.append(line.strip().split(','))
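
    # Illustrative only (hypothetical row; the path and labels are placeholders):
    #   "videos/clip_0001.mp4,3,2,...,1"  ->  ["videos/clip_0001.mp4", "3", "2", ..., "1"]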

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No GPUs available")

    test_chunks = split_list(test_list, num_gpus)

    manager = mp.Manager()
    results = manager.dict()
    processes = []

    # One worker per GPU, each scoring its own chunk of the test set.
    for gpu_id, chunk in enumerate(test_chunks):
        p = mp.Process(target=process_chunk, args=(gpu_id, chunk, results))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    all_results = []
    for res in results.values():
        all_results.extend(res)

    # Each entry is already a fraction in [0, 1], so the mean is the overall accuracy.
    average = sum(all_results) / len(all_results)
    print(f'Average per-skill accuracy across 20 skills: {average}')
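
# Typical launch (assuming this file is saved as, say, eval_basket.py):
#   python eval_basket.py
# Requires the LLaVA-OneVision (LLaVA-NeXT) codebase plus torch, decord, numpy, tqdm.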