|
|
import argparse |
|
|
import traceback |
|
|
import random |
|
|
import re |
|
|
import copy |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import os |
|
|
import json, jsonlines |
|
|
from tqdm import tqdm |
|
|
import pdb |
|
|
import numpy as np |
|
|
import socket |
|
|
from vllm import LLM, SamplingParams |
|
|
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
from transformers import AutoProcessor, AutoTokenizer |
|
|
from qwen_vl_utils import process_vision_info, fetch_image |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from PIL import Image |
|
|
from datasets import load_dataset |
|
|
import sys |
|
|
sys.path.append('/mnt/beegfs/dzhu6/ViLaSR') |
|
|
from utils.edit_image import merge_bbox_movement, parse_bbox_and_movement, plot_movement, plot_bounding_boxes |
|
|
from typing import List, Dict, Any |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
from datetime import datetime |
|
|
import time, csv |
|
|
csv.field_size_limit(sys.maxsize) |
|
|
|
|
|
|
|
|
SYSTEM_PROMPT = """### Guidance: |
|
|
You are a spatial reasoning assistant with access to two powerful visualization tools. |
|
|
Your task is to break down complex spatial problems and iteratively refine your solution through visualization feedback. |
|
|
|
|
|
### Available tools: |
|
|
You can use the following two tools to visualize. After each tool usage, you must wait for and analyze the visualization feedback before proceeding. |
|
|
|
|
|
1. **Object Mapper** |
|
|
- Purpose: Identifies and maps key items in the space |
|
|
- Input format: JSON |
|
|
```json |
|
|
[{{ |
|
|
"index": i, # Image index |
|
|
"bbox_2d": [x1, y1, x2, y2], |
|
|
"label": "object name/description" |
|
|
}}] |
|
|
``` |
|
|
- Output: Generates bounding boxes for visual inspection of the i-th image |
|
|
|
|
|
2. **Path Tracer** |
|
|
- Purpose: Plots movement or connections between points |
|
|
- Input format: JSON |
|
|
```json |
|
|
[{{ |
|
|
"index": i, # Image index |
|
|
"start_point_2d": [x1, y1], |
|
|
"end_point_2d": [x2, y2], |
|
|
"label": "trace_description" |
|
|
}}] |
|
|
``` |
|
|
- Output: Generates visual paths for verification of the i-th image |
|
|
|
|
|
### Required Output Format: |
|
|
For each reasoning step, you must structure your response as follows: |
|
|
<think> [Your detailed reasoning process] </think> Action: [Object Mapper/Path Tracer] |
|
|
```json |
|
|
[JSON format coordinates] |
|
|
``` |
|
|
|
|
|
After your reasoning and iteratively refine your solution through visualization feedback, you should arrive at a final answer and structure your response as follows: |
|
|
<think> [Your detailed reasoning process] </think> Action: Answer |
|
|
<answer> [Your final answer] </answer> |
|
|
|
|
|
### Please NOTE the following reasoning techniques: |
|
|
1. Initial Analysis |
|
|
- Break down the spatial problem |
|
|
- Plan your approach |
|
|
|
|
|
2. Iterative Reasoning for Each Step |
|
|
- Choose appropriate tool |
|
|
- Provide absolute coordinates in JSON format (The top-left corner of the image is (0, 0) and the bottom-right corner is ({width}, {height})) |
|
|
- Observe the visualization output |
|
|
- Reflect on the visualization: |
|
|
* Is the placement/path accurate? |
|
|
* Does it align with your reasoning? |
|
|
* What adjustments are needed? |
|
|
- Backtrack and Adjust: |
|
|
* If errors found, backtrack to previous step to modify actions or decisions as needed""" |
|
|
|
|
|
PROMPT_TEMPLATE = """ |
|
|
### Question: |
|
|
{question} |
|
|
|
|
|
Begin your reasoning. After each tool use, critically evaluate the visualization and adjust if needed: |
|
|
""" |
|
|
|
|
|
# Batch size constant (not referenced in this chunk — TODO confirm where it is used).
BSZ=1
# Hard cap on images per conversation; multi_turn_generate stops scheduling a
# sample once its image list exceeds this.
MAX_IMAGES=45
# Regex for a trailing "###[x1, y1, x2, y2]" crop spec (not used in this chunk —
# presumably applied to image paths elsewhere; verify at call sites).
SUBIMAGE_PATTERN = r".*\#\#\#\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"
# Per-question-type suffix appended to the prompt to pin the answer format.
TYPE_TEMPLATE = {
    "multiple choice": '\nAnswer with the option\'s letter from the given choices directly.',
    "free-form": '',
    "regression": '\nPlease answer the question using a single word or phrase (e.g., 42 or 3.14).',
    "numerical": '\nPlease answer the question using a single word or phrase (e.g., 42 or 3.14).',
    "vci": "",
}
|
|
|
|
|
from dataclasses import dataclass |
|
|
@dataclass
class ProcessData:
    """Per-sample payload handed to process_single_response for one generation turn."""
    index: int                   # position of the sample in samples_info
    response: str                # raw model output text for this turn
    mm_data: Dict                # {'image': [...]} images accumulated so far (PIL images; asserted in process_single_response)
    bbox_list_origin: Dict       # image_idx -> bboxes drawn in earlier turns
    movement_list_origin: Dict   # image_idx -> movements drawn in earlier turns
    finish_reason: str           # engine finish reason for this turn
    is_finished: bool            # True if generation for this sample already ended
    grid_size: int               # maze grid size (see calculate_grid_centers)
|
|
|
|
|
def calculate_grid_centers(image_size=616, grid_size=5):
    """Compute the pixel centers of a grid_size x grid_size maze grid.

    The margins are fixed fractions of the image edge (empirical values for
    the rendered maze images).

    Args:
        image_size: side length of the (square) maze image in pixels.
        grid_size: number of cells per side.

    Returns:
        tuple: (centers, mean_cell_size) where centers is a row-major list of
        (x, y) floats and mean_cell_size is the average of cell width/height.
    """
    # Fixed fractional margins of the rendered maze image.
    margin_left = int(image_size * 0.125)
    margin_right = int(image_size * 0.1)
    margin_bottom = int(image_size * 0.11)
    margin_top = int(image_size * 0.12)

    usable_width = image_size - (margin_left + margin_right)
    usable_height = image_size - (margin_top + margin_bottom)

    cell_width = usable_width / grid_size
    cell_height = usable_height / grid_size

    # Row-major sweep: rows advance y, columns advance x.
    centers = [
        (margin_left + cell_width/2 + col * cell_width,
         margin_top + cell_height/2 + row * cell_height)
        for row in range(grid_size)
        for col in range(grid_size)
    ]
    return centers, (cell_width + cell_height) / 2
|
|
|
|
|
def check_path_tracer(movement_list, centers, cell_size):
    """Validate that every traced endpoint snaps to the maze grid.

    Each start/end point of each movement must lie within half a cell of its
    nearest grid center; otherwise the trace is rejected.

    Args:
        movement_list: list of dicts with 'start_point_2d' and 'end_point_2d'.
        centers: iterable of (x, y) grid-center coordinates.
        cell_size: nominal cell side length in pixels.

    Returns:
        bool: True when all endpoints are close enough to a center.
    """
    half_cell = cell_size / 2
    for step in movement_list:
        for endpoint_key in ('start_point_2d', 'end_point_2d'):
            px, py = int(step[endpoint_key][0]), int(step[endpoint_key][1])
            nearest = min(np.sqrt((px - cx)**2 + (py - cy)**2) for cx, cy in centers)
            if nearest > half_cell:
                return False
    return True
|
|
|
|
|
def check_repetition(allindex, bbox_list_origin, movement_list_origin):
    """Return True if any newly parsed annotation repeats a previous one.

    Args:
        allindex: image_index -> {'bbox_list': [...], 'movement_list': [...]}
            holding the annotations parsed from the current turn.
        bbox_list_origin: image_idx -> bboxes drawn in earlier turns.
        movement_list_origin: image_idx -> movements drawn in earlier turns.

    Returns:
        bool: True when a previously drawn bbox or movement appears again in
        any of the current turn's per-image annotation lists.
    """
    # Cleanup vs. the original: the unused `cnt`/enumerate and the redundant
    # list(...) wrappers around .values() are gone; membership scans use any().
    for entry in allindex.values():
        for bbox_list in bbox_list_origin.values():
            if any(bbox in entry["bbox_list"] for bbox in bbox_list):
                return True
        for movement_list in movement_list_origin.values():
            if any(movement in entry["movement_list"] for movement in movement_list):
                return True
    return False
|
|
|
|
|
def process_single_response(data: ProcessData):
    """Process a single model response for one turn.

    Parses Object Mapper / Path Tracer tool calls out of the response text,
    merges them with annotations from earlier turns, and renders one new
    annotated image per referenced image index.

    Returns:
        dict consumed by multi_turn_generate (keys: index, response,
        finish_reason, is_finished, processed_image_idx, and — on drawing
        turns — image, bbox_list, movement_list), or None if an unexpected
        exception escapes.
    """
    # Already-finished samples pass through untouched — no drawing, no image.
    if data.is_finished is True:
        return {
            'index': data.index,
            'response': data.response,
            'finish_reason': data.finish_reason,
            'is_finished': data.is_finished,
            'processed_image_idx': [None],
        }
    try:
        # Extract this turn's tool calls from the raw response text.
        bbox_list_new, movement_list_new = parse_bbox_and_movement(data.response)
        current_image_index = len(data.mm_data['image'])
        image_index_list, image_list = [], []
        bbox_list, movement_list = data.bbox_list_origin, data.movement_list_origin
        finish_reason = None

        try:
            # Group the parsed calls by target image index:
            # allindex[idx] = {'bbox_list': [...], 'movement_list': [...]}.
            # Both keys always exist once an entry is created.
            allindex = {}
            for tmp_bbox_list in bbox_list_new:
                # deepcopy so later merging cannot mutate the parsed originals.
                tmp_bbox_list = copy.deepcopy(tmp_bbox_list)
                if tmp_bbox_list["index"] in allindex:
                    if "bbox_list" in allindex[tmp_bbox_list["index"]]:
                        allindex[tmp_bbox_list["index"]]["bbox_list"].append(tmp_bbox_list)
                    else:
                        allindex[tmp_bbox_list["index"]]["bbox_list"] = [tmp_bbox_list]
                else:
                    allindex[tmp_bbox_list["index"]] = {'bbox_list': [tmp_bbox_list], 'movement_list': []}
            for tmp_movement_list in movement_list_new:
                tmp_movement_list = copy.deepcopy(tmp_movement_list)
                if tmp_movement_list["index"] in allindex:
                    if "movement_list" in allindex[tmp_movement_list["index"]]:
                        allindex[tmp_movement_list["index"]]["movement_list"].append(tmp_movement_list)
                    else:
                        allindex[tmp_movement_list["index"]]["movement_list"] = [tmp_movement_list]
                else:
                    allindex[tmp_movement_list["index"]] = {'bbox_list': [], 'movement_list': [tmp_movement_list]}
        except Exception as e:
            # Malformed tool JSON: log and mark the sample as errored below.
            traceback.print_exc()
            print("bbox_list_new, movement_list_new: ", bbox_list_new, movement_list_new)
            finish_reason = "ToolGenError"

        # No valid tool call at all, or the image budget is exhausted —
        # either way this sample's generation terminates this turn.
        if len(allindex) == 0:
            finish_reason = "ToolError"
        elif len(data.mm_data['image']) >= MAX_IMAGES+1:
            finish_reason = "TooManyImages"

        if finish_reason is not None:
            # Terminal result: return a copy of the first (original) image so
            # the caller still has something renderable.
            return {
                'index': data.index,
                'processed_image_idx': [None],
                'image': [data.mm_data['image'][0].copy()],
                'response': data.response,
                'finish_reason': finish_reason,
                'bbox_list': bbox_list,
                'movement_list': movement_list,
                'is_finished': True,
            }
        # One rendered image per referenced index; each new image is appended
        # after the current ones, hence current_image_index + cnt.
        for cnt, tmp_index in enumerate(allindex):
            bbox_list_new, movement_list_new = allindex[tmp_index]["bbox_list"], allindex[tmp_index]["movement_list"]
            image_index_new = current_image_index + cnt
            # Merge this turn's annotations into the running per-image state;
            # image_index is the source image to draw on (-1 = invalid call).
            image_index, bbox_list, movement_list = merge_bbox_movement(
                bbox_list_origin=data.bbox_list_origin,
                movement_list_origin=data.movement_list_origin,
                bbox_list_new=bbox_list_new,
                movement_list_new=movement_list_new,
                image_index_new=image_index_new,
            )
            image_index_list.append(image_index)
            if image_index == -1:
                # Invalid reference (e.g. out-of-range index) — terminal error.
                return {
                    'index': data.index,
                    'processed_image_idx': [None],
                    'image': [data.mm_data['image'][0].copy()],
                    'response': data.response,
                    'finish_reason': "ToolError",
                    'bbox_list': bbox_list,
                    'movement_list': movement_list,
                    'is_finished': True,
                }
            # Draw on a copy so the source image stays clean for later turns.
            image = data.mm_data['image'][image_index].copy()
            assert isinstance(image, Image.Image)
            input_width, input_height = image.size

            # Render all accumulated annotations for the new image in place.
            plot_bounding_boxes(image, bbox_list[image_index_new], input_height=input_height, input_width=input_width)
            plot_movement(image, movement_list[image_index_new], input_height=input_height, input_width=input_width)

            image_list.append(image)

        return {
            'index': data.index,
            'processed_image_idx': image_index_list,
            'image': image_list,
            'response': data.response,
            'finish_reason': data.finish_reason,
            'bbox_list': bbox_list,
            'movement_list': movement_list,
            'is_finished': data.is_finished
        }
    except Exception as e:
        # Worker-thread safety net: never let an exception kill the pool map.
        print(f"Error processing response {data.index}: {str(e)}")
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
def save_samples_info(samples_info, save_dir):
    """Persist each sample's text fields (and any images) under save_dir.

    Each sample gets its own directory named after its 'qid'/'id'/'index' key
    (falling back to a unique timestamped name). 'text_data.json' is written
    there; images, when present, go into an 'images/' subdirectory.

    Args:
        samples_info: list of sample dicts (must contain 'prompt', 'sequence',
            'response', 'finish_reason', 'execution_pass').
        save_dir: base output directory.

    Returns:
        list[str]: the directory used for each sample, in input order.
    """

    def get_unique_dir(base_path, prefix='generation'):
        """Generate a not-yet-existing directory path under base_path."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        counter = 0
        while True:
            if counter == 0:
                dir_name = f"{prefix}_{timestamp}"
            else:
                dir_name = f"{prefix}_{timestamp}_{counter}"
            full_path = os.path.join(base_path, dir_name)
            # NOTE(review): exists-then-use is racy across processes; assumed
            # acceptable because each split writes to its own directory.
            if not os.path.exists(full_path):
                return full_path
            counter += 1

    all_sample_dir = []
    for idx, sample in enumerate(samples_info):
        # str(...) guards against integer ids: os.path.join raises TypeError
        # on non-str components, and 'index' keys can be ints.
        if 'qid' in sample:
            sample_dir = os.path.join(save_dir, str(sample['qid']))
        elif 'id' in sample:
            sample_dir = os.path.join(save_dir, str(sample['id']))
        elif 'index' in sample:
            sample_dir = os.path.join(save_dir, str(sample['index']))
        else:
            sample_dir = get_unique_dir(save_dir, 'sample')
        os.makedirs(sample_dir, exist_ok=True)
        all_sample_dir.append(sample_dir)

        text_data = {
            'prompt': sample['prompt'],
            'sequence': sample['sequence'],
            'response': sample['response'],
            'finish_reason': sample['finish_reason'],
            'execution_pass': sample['execution_pass']
        }

        with open(os.path.join(sample_dir, 'text_data.json'), 'w', encoding='utf-8') as f:
            json.dump(text_data, f, indent=2, ensure_ascii=False)

        # Dump the per-turn PIL images, if any, alongside the text record.
        if 'multi_modal_data' in sample and 'image' in sample['multi_modal_data']:
            images_dir = os.path.join(sample_dir, 'images')
            os.makedirs(images_dir, exist_ok=True)
            for img_idx, img in enumerate(sample['multi_modal_data']['image']):
                if isinstance(img, Image.Image):
                    img_path = os.path.join(images_dir, f'image_{img_idx}.png')
                    img.save(img_path)
    return all_sample_dir
|
|
|
|
|
|
|
|
def get_qwen_chat(model, processor, texts, images, sampling_params=None):
    """Generate responses with a HF Qwen-VL model, one (text, image) pair at a time.

    Args:
        model: transformers VLM exposing .generate() and .device.
        processor: matching processor; .tokenizer must be attached.
        texts: list of chat-formatted prompt strings.
        images: per-sample image inputs, same length as texts.
        sampling_params: object exposing .temperature, .top_p and .max_tokens
            (e.g. vllm.SamplingParams). Required.

    Returns:
        list[str]: decoded continuations, one per input pair.
    """
    # The old `sampling_params={}` default was a mutable default AND could
    # never work — the body reads .temperature/.top_p/.max_tokens as
    # attributes, so a dict default only raised AttributeError. Fail fast.
    if sampling_params is None:
        raise ValueError(
            "get_qwen_chat requires sampling_params with .temperature, "
            ".top_p and .max_tokens attributes"
        )
    responses = []
    for text, image in zip(texts, images):
        inputs = processor(text=[text], images=[image], padding=True, return_tensors='pt').to(model.device)

        with torch.inference_mode():
            generated_ids = model.generate(
                **inputs,
                pad_token_id=processor.tokenizer.eos_token_id,
                temperature=sampling_params.temperature,
                top_p=sampling_params.top_p,
                max_new_tokens=sampling_params.max_tokens,
            )

        # Strip the prompt tokens so only the newly generated suffix remains.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        responses.append(response)
    return responses
|
|
|
|
|
def multi_turn_generate(inference_engine, processor, tokenizer, vllm_inputs=None, sampling_params=None, prompt_token_ids=None, use_tqdm=False, save_dir=None, max_num_steps=10):
    """Run the multi-turn draw-and-reason loop over a batch of samples.

    Each turn: generate with the engine, parse tool calls from the responses
    (process_single_response), render new annotated images, and append the
    feedback prompt + image pads to each sample's running 'sequence'. Stops
    when every active sample is finished or max_num_steps turns have run.

    Args:
        inference_engine: a vllm.LLM, or a HF model (routed to get_qwen_chat).
        processor: HF processor (used only on the HF path).
        tokenizer: tokenizer used to decode/encode prompts and for eos_token.
        vllm_inputs: list of dicts with 'id', 'prompt_token_ids',
            'multi_modal_data', 'grid_size'.
        sampling_params: vllm SamplingParams; .n copies of each input are
            expanded, then n is forced back to 1 per-call.
        prompt_token_ids: unused here — kept for caller compatibility.
        use_tqdm: forwarded to vllm generate.
        save_dir: when set, dump per-sample artifacts via save_samples_info.
        max_num_steps: maximum number of LLM calls per sample.

    Returns:
        list[str] of final sequences, or (sequences, sample_dirs) when
        save_dir is given.
    """
    def _get_prompts_and_indices(samples_info):
        # Collect the still-active samples (not stopped, under image budget).
        prompts, multi_modal_data, indices=[], [], []
        for index, info in enumerate(samples_info):
            if not info['stop'] and len(info['multi_modal_data']['image']) <= MAX_IMAGES:
                prompts.append(info['sequence'])
                multi_modal_data.append(info['multi_modal_data'])
                indices.append(info['index'])
        return prompts, multi_modal_data, indices

    # Deep-copy so mutating .n/.detokenize below doesn't leak to the caller.
    sampling_params=copy.deepcopy(sampling_params)
    new_vllm_inputs = []
    for single_vllm_input in vllm_inputs:
        prompt = tokenizer.decode(single_vllm_input['prompt_token_ids'], skip_special_tokens=False)
        # Expand each input into sampling_params.n independent copies.
        new_vllm_inputs.extend([{
            "id": single_vllm_input['id'],
            "prompt": prompt,
            "multi_modal_data": single_vllm_input['multi_modal_data'],
            "grid_size": single_vllm_input['grid_size'],
        } for _ in range(sampling_params.n)])

    # From here on each expanded copy is generated independently (n=1).
    sampling_params.n=1
    sampling_params.detokenize=True
    samples_info = []
    for index, item in enumerate(new_vllm_inputs):
        # Normalize images through qwen_vl_utils (resizing/format handling).
        processed_image = [fetch_image({'image': origin_image}) for origin_image in item['multi_modal_data']['image']]
        sample_info = {
            "id": item["id"],
            "prompt": item["prompt"],
            "sequence": item["prompt"],          # running transcript, grows each turn
            "multi_modal_data": {"image": processed_image},
            "response": "",                      # concatenated model responses
            "stop": False,                       # True once this sample is done
            "finish_reason": None,
            "processed_image_idx": [],
            "index": index,
            "mask_info": [],
            "execution_pass": 0,
            "bbox_list": {img_idx: [] for img_idx in range(len(processed_image))},
            "movement_list": {img_idx: [] for img_idx in range(len(processed_image))},
            "grid_size": item['grid_size'],
        }
        samples_info.append(sample_info)
    # Feedback prompts inserted after each rendered tool image.
    intermediate_prompt = 'The index of the given image is {current_image_idx} (width: {width}, height: {height}). Continue your reasoning. After each tool use, critically evaluate the visualization and adjust if needed:'
    final_prompt = 'The index of the given image is {current_image_idx} (width: {width}, height: {height}). Then, you can not invoke the Object Mapper or Path Tracer tool. Please answer the initial question and structure your response as required:'
    # Chat-template glue: close the assistant turn, open a user turn with the
    # image pads + feedback prompt, then reopen the assistant turn.
    intermediate_template = """<|im_end|>
<|im_start|>user
{pad}
{prompt}
<|im_end|>
<|im_start|>assistant
"""
    num_llm_calls_available = max_num_steps - 1
    while num_llm_calls_available >= 0:
        num_llm_calls_available-=1
        input_prompts, multi_modal_data, indices=_get_prompts_and_indices(samples_info)

        if type(inference_engine) == LLM:
            # vLLM path: re-encode the running sequences with their images.
            input_prompts = [{
                'prompt_token_ids': tokenizer.encode(prompt, add_special_tokens=False)[:],
                'multi_modal_data': mm_data
            } for prompt, mm_data in zip(input_prompts, multi_modal_data)]
            outputs = inference_engine.generate(prompts=input_prompts, sampling_params=sampling_params, use_tqdm=use_tqdm)
        else:
            # HF path: wrap plain strings in objects mimicking vLLM's output
            # shape (request_id / .outputs[0].text / finish_reason).
            responses = get_qwen_chat(inference_engine, processor, input_prompts, [m['image'] for m in multi_modal_data], sampling_params)
            class A:
                pass
            outputs = []
            for request_id, response in enumerate(responses):
                a = A()
                a.outputs = [A()]
                a.outputs[0].text = response
                a.outputs[0].finish_reason = "stop"
                a.outputs[0].stop_reason = None
                a.request_id = request_id
                outputs.append(a)

        # Restore input order (vLLM may complete requests out of order).
        sorted_outputs = sorted(outputs, key=lambda output: int(output.request_id))
        responses=[x.outputs[0].text for x in sorted_outputs]
        finish_reason=[x.outputs[0].finish_reason for x in sorted_outputs]
        stop_reason=[x.outputs[0].stop_reason for x in sorted_outputs]
        if num_llm_calls_available==-1:
            # Last allowed turn: commit responses and stop every active sample.
            for i ,index in enumerate(indices):
                samples_info[index]['response']+=responses[i]
                samples_info[index]['sequence']+=responses[i]
                samples_info[index]['stop']=True
                samples_info[index]['finish_reason']=finish_reason[i]
            break

        def _is_finished(finish_reason, stop_reason, response):
            # Done when a well-formed <answer> was emitted, or the engine
            # cut generation off ('length'), or a rule-based stop fired.
            if finish_reason=='stop' and stop_reason==None and "<answer>" in response and "</answer>" in response:
                return True
            if finish_reason=='length':
                return True
            if finish_reason=='rule':
                return True
            return False

        is_finished=[_is_finished(finish_reason[i], stop_reason[i], responses[i]) for i in range(len(finish_reason))]

        if all([x for x in is_finished]):
            # Every active sample finished this turn — commit and exit early.
            for i ,index in enumerate(indices):
                samples_info[index]['response']+=responses[i]
                samples_info[index]['sequence']+=responses[i]
                samples_info[index]['stop']=True
                samples_info[index]['finish_reason']=finish_reason[i]
            break

        # Parse tool calls and render feedback images in a thread pool.
        process_data_list = [
            ProcessData(
                index=index,
                response=responses[i],
                mm_data=samples_info[index]['multi_modal_data'],
                bbox_list_origin=samples_info[index]["bbox_list"],
                movement_list_origin=samples_info[index]["movement_list"],
                finish_reason=finish_reason[i],
                is_finished=is_finished[i],
                grid_size=samples_info[index]['grid_size'],
            ) for i, index in enumerate(indices)]
        with ThreadPoolExecutor(max_workers=max(min(len(indices), os.cpu_count()//2, 64), 1)) as executor:
            results = list(executor.map(process_single_response, process_data_list))

        for result in results:
            if result is not None:
                index = result['index']
                samples_info[index]['response'] += result['response']
                samples_info[index]['stop'] = result['is_finished']
                samples_info[index]['finish_reason'] = result['finish_reason']
                samples_info[index]['processed_image_idx'].extend(result['processed_image_idx'])
                if result['is_finished'] is False:
                    current_image_count = len(samples_info[index]['multi_modal_data']['image'])
                    if len(result["image"]) > 1:
                        # Multiple images rendered this turn: build one pad
                        # string with an entry per image; the prompt text on
                        # the last one asks the model to continue (or answer).
                        current_image_idx = current_image_count + 1
                        pad_prompt = ""
                        for tmp_image_idx, tmp_image in enumerate(result["image"]):
                            width, height = fetch_image({"image": tmp_image}).size
                            if current_image_count + tmp_image_idx + 1>=MAX_IMAGES:
                                # Budget reached: force the final-answer prompt
                                # and drop any remaining rendered images.
                                pad_prompt += f"<|vision_start|><|image_pad|><|vision_end|>" + final_prompt.format(
                                    current_image_idx=current_image_idx + tmp_image_idx,
                                    width=width,
                                    height=height,
                                )
                                samples_info[index]['multi_modal_data']['image'].append(tmp_image)
                                break
                            else:
                                if tmp_image_idx <= len(result["image"]) - 2:
                                    # Not the last image: just announce it.
                                    pad_prompt += f"<|vision_start|><|image_pad|><|vision_end|>The index of the given image is {current_image_idx+tmp_image_idx} (width: {width}, height: {height}).\n"
                                else:
                                    # Last image: continue or wrap up depending
                                    # on remaining call budget.
                                    if num_llm_calls_available > 0:
                                        pad_prompt += f"<|vision_start|><|image_pad|><|vision_end|>" + intermediate_prompt.format(
                                            current_image_idx=current_image_idx + tmp_image_idx,
                                            width=width,
                                            height=height,
                                        )
                                    else:
                                        pad_prompt += f"<|vision_start|><|image_pad|><|vision_end|>" + final_prompt.format(
                                            current_image_idx=current_image_idx + tmp_image_idx,
                                            width=width,
                                            height=height,
                                        )
                                samples_info[index]['multi_modal_data']['image'].append(tmp_image)

                        samples_info[index]['sequence'] += result['response'] + intermediate_template.format(prompt="", pad=pad_prompt)
                    else:
                        # Single rendered image this turn.
                        current_image_idx = current_image_count + 1
                        width, height = fetch_image({"image": result["image"][0]}).size
                        if current_image_count + 1>= MAX_IMAGES:
                            prompt = final_prompt.format(
                                current_image_idx=current_image_idx,
                                width=width,
                                height=height,
                            )
                        else:
                            prompt = (intermediate_prompt if num_llm_calls_available > 0 else final_prompt).format(
                                current_image_idx=current_image_idx,
                                width=width,
                                height=height,
                            )

                        samples_info[index]['sequence'] += result['response'] + intermediate_template.format(
                            prompt=prompt,
                            pad="<|vision_start|><|image_pad|><|vision_end|>"
                        )
                        samples_info[index]['multi_modal_data']['image'].append(result['image'][0])

                    # Carry the merged annotation state into the next turn.
                    samples_info[index]['bbox_list'] = result['bbox_list']
                    samples_info[index]["movement_list"] = result['movement_list']
                else:
                    samples_info[index]['sequence'] += result['response']

    # Close every sequence with EOS unless generation was cut off by length.
    for i, line in enumerate(samples_info):
        if samples_info[i]['finish_reason']!='length':
            samples_info[i]['sequence']+=tokenizer.eos_token

    batch_sequences = [sample['sequence'] for sample in samples_info]
    if save_dir:
        all_sample_dir = save_samples_info(samples_info, save_dir)
        return batch_sequences, all_sample_dir
    return batch_sequences
|
|
|
|
|
|
|
|
def parse_dialog(serialized_content):
    """Parse a serialized Qwen chat transcript back into a conversation list.

    Splits on <|im_start|>/<|im_end|> markers and rebuilds role-tagged turns:
    system -> plain string content, user -> list of image/text content parts,
    assistant -> plain string content.

    Args:
        serialized_content: the flat prompt string produced during generation.

    Returns:
        list[dict]: [{"role": ..., "content": ...}, ...]; empty list for
        empty/marker-only input (previously raised IndexError).
    """
    segments = re.split(r'<\|im_start\|>|<\|im_end\|>', serialized_content)
    segments = [s for s in segments if s]
    if not segments:
        return []

    conversations = []

    system_content = None
    if segments[0].startswith('system'):
        system_content = segments[0].replace('system\n\n', '', 1)
        segments = segments[1:]

    if system_content:
        conversations.append({
            "role": "system",
            "content": system_content
        })

    for segment in segments:
        if segment.startswith('user'):
            has_vision = '<|vision_start|><|image_pad|><|vision_end|>' in segment
            text = segment.replace('user\n', '', 1)

            content = []
            if has_vision:
                # NOTE(review): depends on a module-global `args`
                # (args.max_pixels) being set before this runs — confirm
                # at call sites; parsing text-only dialogs never hits this.
                content.append({
                    "type": "image",
                    "image": "image_path",
                    "nframes": "args.max_frames",
                    "max_pixels": args.max_pixels
                })
            content.append({
                "type": "text",
                "text": text
            })

            conversations.append({
                "role": "user",
                "content": content
            })
        elif segment.startswith('assistant'):
            text = segment.replace('assistant\n', '', 1)
            conversations.append({
                "role": "assistant",
                "content": text
            })

    return conversations
|
|
|
|
|
|
|
|
def setup_distributed():
    """Initialize the NCCL process group from SLURM or torchrun-style env vars.

    Returns:
        tuple: (rank, world_size, local_rank) of this process.
    """
    env = os.environ
    if "SLURM_PROCID" in env:
        # Launched under SLURM: derive topology from SLURM variables.
        rank = int(env["SLURM_PROCID"])
        world_size = int(env["SLURM_NTASKS"])
        local_rank = int(env["SLURM_LOCALID"])
    else:
        # Fall back to standard torch.distributed variables (single process
        # defaults when nothing is set).
        rank = int(env.get("RANK", 0))
        world_size = int(env.get("WORLD_SIZE", 1))
        local_rank = int(env.get("LOCAL_RANK", 0))

    # Ensure rendezvous variables exist for init_process_group.
    env["MASTER_ADDR"] = env.get("MASTER_ADDR", "127.0.0.1")
    env["MASTER_PORT"] = env.get("MASTER_PORT", "29500")

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
|
|
|
|
|
|
|
|
def eval_model(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_path = args.model_path |
|
|
model_name = args.model_name |
|
|
|
|
|
print(f"prune: {args.prune}") |
|
|
print(f"resize: {args.resize}") |
|
|
print(f"Loading from {model_path}") |
|
|
print("torch.cuda.device_count():", torch.cuda.device_count()) |
|
|
if args.prune and 'vscan' in args.prune.lower(): |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( |
|
|
Qwen2_5_VisionTransformerPretrainedModel, |
|
|
Qwen2_5_VLVisionBlock, |
|
|
Qwen2_5_VLVisionSdpaAttention, |
|
|
Qwen2_5_VLVisionFlashAttention2, |
|
|
Qwen2_5_VisionPatchEmbed, |
|
|
Qwen2_5_VLModel |
|
|
) |
|
|
sys.path.append('/mnt/beegfs/dzhu6/SelfEvolvingAgent/VScan/') |
|
|
from qwen.model.qwen2_5_vl_custom import ( |
|
|
Qwen2_5_VLForConditionalGeneration_X, |
|
|
Qwen2_5_VisionTransformerPretrainedModel_X, |
|
|
Qwen2_5_VLVisionBlock_X, |
|
|
Qwen2_5_VLVisionSdpaAttention_X, |
|
|
Qwen2_5_VLVisionFlashAttention2_X, |
|
|
Qwen2_5_VisionPatchEmbed_X, |
|
|
Qwen2_5_VLModel_X |
|
|
) |
|
|
Qwen2_5_VLForConditionalGeneration.forward = Qwen2_5_VLForConditionalGeneration_X.forward |
|
|
Qwen2_5_VisionTransformerPretrainedModel.forward = Qwen2_5_VisionTransformerPretrainedModel_X.forward |
|
|
Qwen2_5_VLVisionBlock.forward = Qwen2_5_VLVisionBlock_X.forward |
|
|
Qwen2_5_VLVisionSdpaAttention.forward = Qwen2_5_VLVisionSdpaAttention_X.forward |
|
|
Qwen2_5_VLVisionFlashAttention2.forward = Qwen2_5_VLVisionFlashAttention2_X.forward |
|
|
Qwen2_5_VisionPatchEmbed.forward = Qwen2_5_VisionPatchEmbed_X.forward |
|
|
Qwen2_5_VLModel.forward = Qwen2_5_VLModel_X.forward |
|
|
Qwen2_5_VLModel.layer_prune = Qwen2_5_VLModel_X.layer_prune |
|
|
print("Qwen2_5_VLForConditionalGeneration.forward:", Qwen2_5_VLForConditionalGeneration.forward) |
|
|
llm = Qwen2_5_VLForConditionalGeneration.from_pretrained( |
|
|
args.model_path, |
|
|
torch_dtype=torch.bfloat16, |
|
|
device_map=f"auto", |
|
|
attn_implementation="flash_attention_2", |
|
|
trust_remote_code=True |
|
|
).eval() |
|
|
llm.model.layer_list = [14] |
|
|
llm.model.image_token_ratio_list = [0.333] |
|
|
llm.image_token_ratio = 0.167 |
|
|
min_pixels = 4*28*28 |
|
|
max_pixels = 1280*28*28 |
|
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor_parallel_size = int(os.environ.get("SLURM_NTASKS", "1")) |
|
|
pipeline_parallel_size = int(os.environ.get("VLLM_PIPELINE_PARALLEL_SIZE", "1")) |
|
|
|
|
|
print(f"Loading from {model_path}, tensor_parallel_size: {tensor_parallel_size}, pipeline_parallel_size: {pipeline_parallel_size}") |
|
|
llm = LLM( |
|
|
trust_remote_code=True, |
|
|
model=model_path, |
|
|
dtype="bfloat16", |
|
|
tensor_parallel_size=tensor_parallel_size, |
|
|
pipeline_parallel_size=pipeline_parallel_size, |
|
|
limit_mm_per_prompt={"image": 62, "video": 10}, |
|
|
gpu_memory_utilization=0.85, |
|
|
enable_prefix_caching=True |
|
|
) |
|
|
|
|
|
processor = AutoProcessor.from_pretrained(model_path) |
|
|
sampling_params = SamplingParams( |
|
|
temperature=args.temperature, |
|
|
top_p=args.top_p, |
|
|
max_tokens=16384, |
|
|
stop_token_ids=[], |
|
|
) |
|
|
processor = AutoProcessor.from_pretrained(model_path) |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path) |
|
|
tokenizer.padding_side = "left" |
|
|
processor.tokenizer = tokenizer |
|
|
''' |
|
|
file_path = args.input_file |
|
|
if file_path.endswith('.jsonl'): |
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
|
data = [json.loads(line) for line in f] |
|
|
else: |
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
|
data = json.load(f) |
|
|
''' |
|
|
if 'blink' in args.dataset.lower(): |
|
|
sys.path.append('/mnt/beegfs/dzhu6/VisualSketchpad/') |
|
|
from infer2 import get_blink_dataset |
|
|
data = get_blink_dataset() |
|
|
elif 'spatialeval' in args.input_file.lower(): |
|
|
data = load_dataset("MilaWang/SpatialEval", "vqa", split="test") |
|
|
elif '3dsr' in args.dataset.lower(): |
|
|
dp = "/mnt/beegfs/dzhu6/correlation/3dsrbench_v1_vlmevalkit_circular.tsv" |
|
|
with open(dp, 'r') as f: |
|
|
reader = csv.DictReader(f, delimiter='\t') |
|
|
data = list(reader) |
|
|
|
|
|
st, ed = (len(data)*args.split)//args.all, (len(data)*(args.split+1))//args.all |
|
|
print(f"{len(data)} lines found, generating from {st} to {ed}") |
|
|
print("Data: ", len(data)) |
|
|
if type(data) is list: |
|
|
data = data[st:ed] |
|
|
else: |
|
|
data = data.select(range(st, ed)) |
|
|
print("Data: ", len(data)) |
|
|
batches = [] |
|
|
messages = [] |
|
|
ids = [] |
|
|
def get_key(x):
    """Return a stable identifier for a sample record.

    Prefers 'id', then 'qid'; falls back to 'index' (raising KeyError if
    none of the three is present).
    """
    for candidate in ('id', 'qid'):
        if candidate in x:
            return x[candidate]
    return x['index']
|
|
|
|
|
start_idx = 0 |
|
|
old = set() |
|
|
if args.all > 1: |
|
|
|
|
|
output_dir = os.path.join(args.output_dir, f"split_{args.split}_all_{args.all}") |
|
|
else: |
|
|
output_dir = args.output_dir |
|
|
save_dir = output_dir |
|
|
if args.over_write: |
|
|
os.system(f"rm -rf {output_dir} && mkdir {output_dir}") |
|
|
else: |
|
|
if not os.path.exists(output_dir): |
|
|
os.system(f"mkdir {output_dir}") |
|
|
|
|
|
print("Output Dir: ", output_dir) |
|
|
output_file_path = f"{output_dir}/results.jsonl" |
|
|
if os.path.exists(output_file_path): |
|
|
mode = "a" |
|
|
with jsonlines.open(output_file_path) as fin: |
|
|
for line in fin: |
|
|
old.add(get_key(line)) |
|
|
|
|
|
|
|
|
else: |
|
|
mode = "w" |
|
|
print("loaded ", len(old), " lines from ", output_file_path) |
|
|
|
|
|
# Build one chat `msg` per remaining sample, plus index-aligned bookkeeping
# in `ids` and `batches`.
for xidx, x in enumerate(data):
    # Skip samples already present in the output file (resume support).
    if get_key(x) in old:
        continue

    ids.append(get_key(x))
    # Keep a copy of the raw record (minus any bulky inline image payload)
    # so it can be written back out alongside the model output.
    batches.append(x.copy())
    batches[-1].pop("image", None)

    # Example record shape (SpatialEval):
    # {
    #     "dataset": "SpatialEval",
    #     "task": "spatialreal",
    #     "split": "test",
    #     "question_id": "spatialreal.vqa.sa_1543979.0",
    #     "question": "... How many individual cartons of strawberries ...",
    #     "answer": "D",
    #     "image_path": ["SpatialEval/spatialreal.vqa.sa_1543979.0.png"],
    #     "data_type": "image",
    #     "problem_type": "multiple choice"
    # }

    if args.dataset in ["vsi_bench"]:
        # Video benchmark: one prompt over args.max_frames frames.
        # (Fixed prompt typo: "quesntion" -> "question".)
        prompt = f"These are frames from a video, numbered from 1 to {args.max_frames} in sequence. That is, the index of each image is 1, 2, 3, ..., {args.max_frames}.\n\nAnswer the question with appropriate tools:\n" + x['question']
        if x['problem_type'] == 'multiple choice' and 'options' in x:
            prompt = prompt + '\n' + '\n'.join(x['options'])
        prompt = prompt + TYPE_TEMPLATE[x['problem_type'].lower()]
        # The first frame's dimensions are used for every frame caption.
        width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
        image_messages = []
        for image_idx, image_path in enumerate(x["image_path"]):
            image_messages.extend([
                {
                    "type": "image",
                    "image": os.path.join(args.image_folder, image_path),
                    "nframes": args.max_frames,
                    "max_pixels": args.max_pixels
                },
                {
                    "type": "text",
                    "text": f"The index of the given image is {image_idx+1} (width: {width}, height: {height}).\n",
                }
            ])
        image_messages.append({
            "type": "text",
            "text": PROMPT_TEMPLATE.format(question=prompt)
        })
        msg = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.format(width=width, height=height)
            },
            {
                "role": "user",
                "content": image_messages,
            }
        ]
    elif "3dsrbench" in args.dataset.lower():
        x['index'] = x['qid']
        question = x['question']
        question += f"\nA. {x['A']}"
        question += f"\nB. {x['B']}"
        # Options C and D are optional in this benchmark.
        if "C" in x and x['C'] is not None and x['C'].strip() != "":
            question += f"\nC. {x['C']}"
        if "D" in x and x['D'] is not None and x['D'].strip() != "":
            question += f"\nD. {x['D']}"
        question += "\n"
        prompt = question

        # "flip" items evaluate on a horizontally mirrored copy of the COCO
        # image; generate and cache the flipped file on first use.
        # (The original nested this exact condition twice; once suffices.)
        if "flip" in x['index'].lower():
            ext = x['image_url'].split('.')[-1]
            flip = x['image_url'].replace("http://images.cocodataset.org/", "/mnt/beegfs/dzhu6/coco_images/").replace("." + ext, "_flip." + ext)
            if not os.path.exists(flip):
                image = Image.open(x['image_url'].replace("http://images.cocodataset.org/", "/mnt/beegfs/dzhu6/coco_images/")).convert("RGB")
                image = image.transpose(method=Image.FLIP_LEFT_RIGHT)
                image.save(flip)
            x['image_url'] = x['image_url'].replace('.' + ext, "_flip." + ext)

        x["problem_type"] = "multiple choice"
        x['image_path'] = [x['image_url'].replace("http://images.cocodataset.org/", "")]
        prompt = prompt + '\nThe index of the given image is 1.' + TYPE_TEMPLATE[x['problem_type'].lower()]

        width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
        if args.resize:
            width, height = width//2, height//2
        msg = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.format(width=width, height=height)
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": os.path.join(args.image_folder, x["image_path"][0]),
                        "nframes": args.max_frames,
                        "grid_size": x["grid_size"] if "grid_size" in x else None,
                        "max_pixels": args.max_pixels
                    },
                    {
                        "type": "text",
                        "text": PROMPT_TEMPLATE.format(question=prompt)
                    }
                ]
            }]
    elif "blink" in args.dataset.lower():
        # BLINK records already carry their chat messages; only the system
        # prompt (with image dimensions, when available) is prepended.
        image_inputs, _ = process_vision_info(x['messages'])
        if image_inputs:
            width, height = image_inputs[0].size
            if args.resize:
                width, height = width//2, height//2
            msg = [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT.format(width=width, height=height)
                },
            ]
        else:
            msg = [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT.format(width=0, height=0)
                },
            ]
        msg += x['messages']
    elif 'spatialeval' in args.dataset.lower():
        x["problem_type"] = "multiple choice"
        x['image_path'] = [x['id'] + '.png']
        prompt = x["text"]
        prompt = prompt + '\nThe index of the given image is 1.' + TYPE_TEMPLATE[x['problem_type'].lower()]
        width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
        if args.resize:
            width, height = width//2, height//2
        msg = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.format(width=width, height=height)
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": os.path.join(args.image_folder, x["image_path"][0]),
                        "nframes": args.max_frames,
                        "grid_size": x["grid_size"] if "grid_size" in x else None,
                        "max_pixels": args.max_pixels
                    },
                    {
                        "type": "text",
                        "text": PROMPT_TEMPLATE.format(question=prompt)
                    }
                ]
            }]
    elif args.dataset in ["maze", "SpatialEval_spatialreal",]:
        prompt = x["question"]
        if x['problem_type'] == 'multiple choice' and 'options' in x:
            prompt = prompt + '\n' + '\n'.join(x['options'])
        prompt = prompt + '\nThe index of the given image is 1.' + TYPE_TEMPLATE[x['problem_type'].lower()]
        width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
        msg = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT.format(width=width, height=height)
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": os.path.join(args.image_folder, x["image_path"][0]),
                        "nframes": args.max_frames,
                        "grid_size": x["grid_size"] if "grid_size" in x else None,
                        "max_pixels": args.max_pixels
                    },
                    {
                        "type": "text",
                        "text": PROMPT_TEMPLATE.format(question=prompt)
                    }
                ]
            }]
    elif args.dataset in ["spar_bench", "spar_bench_tiny", "mmsi_bench"]:
        prompt = x["question"]
        if x['problem_type'] == 'multiple choice' and x.get('options', None) is not None:
            prompt = prompt + '\n' + '\n'.join(x['options'])
        # Strip benchmark-specific answer-format boilerplate; PROMPT_TEMPLATE
        # supplies its own answer-format instructions.
        prompt = prompt.replace("Your answer can only include one of options A, B, C or D.", "")
        prompt = prompt.replace("Answer using a single number and nothing else.", "")

        post_prompt = ""
        if x.get('original_question_type', None) in ['position_matching', "camera_motion_infer"]:
            post_prompt = "The values represent the bounding box coordinates normalized to a 0-1000 scale, with the top-left corner as the origin of the image."
        prompt = prompt + "\n" + post_prompt

        if x['data_type'] == 'single_view':
            prompt = prompt + '\nThe index of the given image is 1.' + TYPE_TEMPLATE[x['problem_type'].lower()]
            width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
            msg = [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT.format(width=width, height=height)
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": os.path.join(args.image_folder, x["image_path"][0]),
                            "max_pixels": args.max_pixels
                        },
                        {
                            "type": "text",
                            "text": PROMPT_TEMPLATE.format(question=prompt)
                        }
                    ]
                }
            ]
        elif x['data_type'] == 'multi_view':
            n_frames = len(x["image_path"])
            # First view's dimensions are used for every per-view caption.
            width, height = fetch_image({"image": os.path.join(args.image_folder, x["image_path"][0]), "max_pixels": args.max_pixels}).size
            image_messages = []
            for image_idx, image_path in enumerate(x["image_path"]):
                image_messages.extend([
                    {
                        "type": "image",
                        "image": os.path.join(args.image_folder, image_path),
                        "max_pixels": args.max_pixels
                    },
                    {
                        "type": "text",
                        "text": f"The index of the given image is {image_idx+1} (width: {width}, height: {height}).\n"
                    }
                ])
            prompt = prompt + TYPE_TEMPLATE[x['problem_type'].lower()]
            image_messages.append({
                "type": "text",
                "text": PROMPT_TEMPLATE.format(question=prompt)
            })

            msg = [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT.format(width=width, height=height)
                },
                {
                    "role": "user",
                    "content": image_messages
                }
            ]
    else:
        # (Fixed exception-message typo: "UNKNON" -> "UNKNOWN".)
        raise Exception(f"UNKNOWN args.dataset: {args.dataset}")
    messages.append(msg)

print("messages: ", len(messages))
print("batches: ", len(batches))
|
|
|
|
|
|
|
|
with open(output_file_path, mode, encoding="utf-8") as fout:
    print("Message Example:", messages[0])
    print(f"Start from the {start_idx} example")
    for i in tqdm(range(start_idx, len(messages), BSZ), desc="Processing batches"):
        batch_messages = messages[i:i + BSZ]
        prompts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]

        # Count images per conversation and sanity-check against the dataset.
        image_num = []
        for msg in batch_messages:
            current_image_num = 0
            for turn in msg:
                if isinstance(turn["content"], list):
                    for turn_content in turn["content"]:
                        if turn_content["type"] == "image":
                            current_image_num += 1
            if args.dataset in ["vsi_bench"]:
                assert current_image_num == args.max_frames, f"wrong image number: {current_image_num} != {args.max_frames}"
            elif args.dataset in ["maze", "SpatialEval_spatialreal", "SpatialEval"]:
                assert current_image_num == 1, f"wrong image number: {current_image_num}"
            image_num.append(current_image_num)

        image_inputs, video_inputs, video_kwargs = process_vision_info(batch_messages, return_video_kwargs=True)

        # Re-associate the flat image/video lists with their conversations.
        image_idx = 0
        video_idx = 0
        llm_inputs = []
        for idx, (prompt, msg) in enumerate(zip(prompts, batch_messages)):
            # The second turn is the user turn; its first content item tells
            # us whether this sample is image- or video-based.
            mm_type = batch_messages[idx][1]['content'][0]['type']
            sample_mm_data = {}
            sample_video_kw = {}
            if mm_type == 'image':
                sample_mm_data["image"] = []
                for current_idx in range(image_num[idx]):
                    image = image_inputs[image_idx]
                    if args.resize:
                        width, height = image.size
                        image = image.resize((width//2, height//2), resample=Image.Resampling.LANCZOS)
                    # (The original's per-dataset if/else appended the same
                    # image in both branches; collapsed to one append.)
                    sample_mm_data["image"].append(image)
                    image_idx += 1
            elif mm_type == 'video':
                sample_mm_data["video"] = [video_inputs[video_idx]]
                for key, value in video_kwargs.items():
                    sample_video_kw[key] = value[video_idx]
                video_idx += 1
            llm_inputs.append({
                'id': ids[i + idx],
                "prompt": prompt,
                "prompt_token_ids": tokenizer.encode(prompt, add_special_tokens=False),
                "multi_modal_data": sample_mm_data,
                "mm_processor_kwargs": sample_video_kw,
                "grid_size": msg[1]["content"][0]["grid_size"] if args.dataset == 'maze' else None
            })

        # Every preprocessed image/video must have been consumed exactly once.
        if image_inputs is not None:
            assert image_idx == len(image_inputs), f"Image index mismatch: {image_idx} != {len(image_inputs)}"
        if video_inputs is not None:
            assert video_idx == len(video_inputs), f"Video index mismatch: {video_idx} != {len(video_inputs)}"

        try:
            # (The original guarded this call with `if i < 1e9`, whose else
            # branch was unreachable for any loop index; the save_dir path is
            # always taken.)
            batch_sequences, all_sample_dir = multi_turn_generate(
                llm, processor, tokenizer, vllm_inputs=llm_inputs,
                sampling_params=sampling_params, save_dir=save_dir,
                max_num_steps=20 if args.dataset == "maze" else 10)
            batch_conversations = [parse_dialog(sequence) for sequence in batch_sequences]
            print(f"Processed batch {(i)//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}. ")
        except Exception as e:
            # Best-effort: log the failure (with traceback) and move on to
            # the next batch rather than aborting the whole run.
            print(f"Error processing batch starting at index {i}: {e}")
            traceback.print_exc()
            continue

        for input_example, model_output, sample_dir in zip(batches[i:i + BSZ], batch_conversations, all_sample_dir):
            result = input_example.copy()
            result['conversations'] = model_output
            result['model_output'] = model_output[-1]['content']
            result['model_id'] = model_name
            result['sample_dir'] = sample_dir

            fout.write(
                json.dumps(result)
                + "\n"
            )
            fout.flush()

            # Dataset-specific mirrors of the result for external scorers.
            if 'spatialeval' in args.dataset.lower():
                with jsonlines.open(f"/mnt/beegfs/dzhu6/SpatialEval/outputs/vqa/spatialreal/m-{model_name}_bare_prune_{args.prune}{'_resize_True' if args.resize else ''}_split_{args.split}_all_{args.all}.jsonl", "a") as writer:
                    result = input_example.copy()
                    result['answer'] = model_output[-1]['content']
                    writer.write(result)
            elif '3dsr' in args.dataset.lower():
                with jsonlines.open(f"/mnt/beegfs/dzhu6/correlation/results/{model_name}_prune_{args.prune}{'_resize_True' if args.resize else ''}_split_{args.split}_all_{args.all}.jsonl", "a") as writer:
                    result = input_example.copy()
                    result['response'] = model_output[-1]['content']
                    writer.write(result)
            elif 'blink' in args.dataset.lower():
                with jsonlines.open(f"/mnt/beegfs/dzhu6/VisualSketchpad/{model_name}_traj_False_prune_{args.prune}{'_resize_True' if args.resize else ''}_maxnewtokens_16384_split_{args.split}_all_{args.all}.jsonl", "a") as writer:
                    result = input_example.copy()
                    result.pop('messages', None)
                    result['response'] = model_output[-1]['content']
                    writer.write(result)
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--dataset", type=str, required=True, help="")
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--input-file", type=str, required=True, help="Path to the question file")
    parser.add_argument("--output-dir", type=str, default="./result")
    parser.add_argument("--temperature", type=float, default=0.75)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max-frames", type=int, default=32)
    parser.add_argument("--max-pixels", type=int, default=256*28*28)
    parser.add_argument("--over_write", type=int, default=0, help="Whether to overwrite the output directory")
    # Splits are 0-indexed: worker `split` of `all` handles
    # data[(len*split)//all : (len*(split+1))//all]. The previous default of 1
    # combined with the default --all 1 selected an empty slice, so the
    # default must be 0.
    parser.add_argument("--split", type=int, default=0)
    parser.add_argument("--all", type=int, default=1)
    parser.add_argument("--prune", type=str, default=None)
    parser.add_argument("--resize", action="store_true", default=False)
    args = parser.parse_args()
    # Launcher scripts pass the literal string "None" to mean "no prefix".
    if args.image_folder == "None":
        args.image_folder = ""
    eval_model(args)
|
|
|
|
|
|