| import json | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from infer_utils import run_inference_single | |
| import numpy as np | |
def _sample_records(records, data_size, data_frac):
    """Return a random subset of *records*, drawn without replacement.

    ``data_size`` (an absolute count, capped at ``len(records)``) takes
    precedence over ``data_frac`` (used only when it is below 1). When neither
    applies, the list is returned unchanged. Sampling goes through NumPy's
    global RNG, so the subset differs between runs unless the caller seeds it.
    """
    total = len(records)
    if data_size is not None:
        count = min(data_size, total)
    elif data_frac < 1:
        count = int(total * data_frac)
    else:
        return records
    chosen = np.random.choice(total, count, replace=False)
    return [records[i] for i in chosen]


def _tmp_answers_path(answer_path, start_ind, end_ind):
    """Build the path of the incremental results file for one index range.

    ``_tmp_<start>_<end>`` (``end`` is the literal ``"end"`` when *end_ind* is
    None) is inserted before the final suffix of the file name only, e.g.
    ``out/answers.json`` -> ``out/answers_tmp_0_end.json``.
    """
    target = Path(answer_path)
    range_end = "end" if end_ind is None else end_ind
    # Only the file name is rewritten. A plain str.replace(".json", ...) would
    # also rewrite ".json" inside directory names, and would drop the range
    # tag entirely when the name has no ".json" in it.
    tagged_name = f"{target.stem}_tmp_{start_ind}_{range_end}{target.suffix}"
    return str(target.with_name(tagged_name))


def run_geochat_inference(
    model,
    dataset_path,
    processor,
    tokenizer,
    conv_mode,
    answer_path,
    use_video_data=False,
    open_prompt=None,
    repeat_frames=None,
    prompt_strategy="interleave",
    chronological_prefix=True,
    data_frac=1,
    data_size=None,
    delete_system_prompt=False,
    start_ind=0,
    end_ind=None,
    print_prompt=False
):
    """Run inference over a JSON dataset and record one answer per question.

    The dataset at *dataset_path* must be a JSON list of question dicts with
    the keys ``id``, ``conversations`` (item 0 holds the prompt and item 1 the
    reference answer, each under ``value``), ``metadata``, ``video``, ``task``,
    ``original_input_polygon`` and ``dataset``.

    Processing steps:
      1. Optionally subsample the dataset (``data_size`` wins over
         ``data_frac``).
      2. Keep only ``[start_ind:end_ind]`` of the (possibly subsampled) list.
      3. Call ``run_inference_single`` for every remaining question and append
         the result, as one JSON object per line, to a temporary file derived
         from *answer_path* (see ``_tmp_answers_path``). The file is opened in
         append mode, so re-running the same range adds to an existing file.

    Args:
        model, processor, tokenizer, conv_mode: Forwarded unchanged to
            ``run_inference_single``.
        dataset_path: Path of the JSON dataset file.
        answer_path: Final answers path; only used to derive the temporary
            per-range file name. This function does not write to it.
        use_video_data, repeat_frames, prompt_strategy, chronological_prefix,
            delete_system_prompt, print_prompt: Forwarded unchanged to
            ``run_inference_single``.
        open_prompt: Accepted for interface compatibility; not used here.
        data_frac: Fraction of the dataset to sample when below 1.
        data_size: Absolute number of questions to sample; overrides
            ``data_frac`` when given.
        start_ind: First index (inclusive) of the slice to process.
        end_ind: Last index (exclusive) of the slice, or None for "to the end".

    Returns:
        dict mapping each question id to its entry (``id``, ``question``,
        ``predicted``, ``ground_truth``, ``task``, ``original_input_polygon``,
        ``dataset``). Later entries replace earlier ones that share an id.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(dataset_path, encoding="utf-8") as f:
        qfabric_data = json.load(f)

    qfabric_data = _sample_records(qfabric_data, data_size, data_frac)

    answers = {}
    answers_tmp = _tmp_answers_path(answer_path, start_ind, end_ind)
    # A None end index slices through to the end of the list.
    qfabric_data = qfabric_data[start_ind:end_ind]

    print("answers_tmp: ", answers_tmp)
    print("start ind: ", start_ind)
    print("end ind: ", end_ind)

    if not qfabric_data:
        # Nothing selected: no file is written and no inference is run.
        return answers

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before any inference work is done rather than after it.
    Path(answers_tmp).parent.mkdir(parents=True, exist_ok=True)

    for question in tqdm(qfabric_data):
        question_id = question["id"]
        conversations = question["conversations"]
        inp = conversations[0]['value']
        answer_str = conversations[1]['value']

        outputs = run_inference_single(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            conv_mode=conv_mode,
            inp=inp,
            image_paths=question['video'],
            metadata=question['metadata'],
            repeat_frames=repeat_frames,
            use_video_data=use_video_data,
            prompt_strategy=prompt_strategy,
            chronological_prefix=chronological_prefix,
            delete_system_prompt=delete_system_prompt,
            print_prompt=print_prompt
        )

        entry = {
            "id": question_id,
            "question": inp,
            "predicted": outputs,
            "ground_truth": answer_str,
            "task": question['task'],
            "original_input_polygon": question['original_input_polygon'],
            "dataset": question['dataset'],
        }
        answers[question_id] = entry

        # Re-open in append mode for every entry so each result is on disk as
        # soon as it is produced, even if a later question fails.
        with open(answers_tmp, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")

    return answers