|
|
import os |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
from lmms_eval.api.registry import ALL_TASKS |
|
|
from lmms_eval.tasks import ( |
|
|
ConfigurableTask, |
|
|
get_task_dict, |
|
|
include_path, |
|
|
initialize_tasks, |
|
|
) |
|
|
|
|
|
|
|
|
def rank0_print(*args):
    """Print ``args`` from a single process when running distributed.

    If torch.distributed is available and a process group has been initialized,
    only rank 0 prints, with a ``"Rank 0: "`` prefix. Otherwise (single process,
    or no process group) ``args`` are printed unchanged, exactly like ``print``.
    """
    # Check is_available() first: on torch builds without distributed support,
    # torch.distributed does not expose is_initialized()/get_rank() at all.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        if rank == 0:
            print(f"Rank {rank}: ", *args)
    else:
        print(*args)
|
|
|
|
|
|
|
|
class BaseEmbedder(ABC):
    """Abstract base class for task embedders.

    Subclasses must implement :meth:`embed_task`. The base class creates the
    output directory and provides :meth:`init_task`, which loads an lmms_eval
    task and collects its built request instances.
    """

    def __init__(self, name: str, output_path: str) -> None:
        """Store the embedder name and output directory, then initialize tasks.

        Args:
            name: Identifier for this embedder.
            output_path: Output directory; created (with parents) if missing.
        """
        super().__init__()
        self.name = name
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)
        # NOTE(review): presumably populates the lmms_eval task registry so that
        # get_task_dict() in init_task can resolve task names — confirm.
        initialize_tasks()

    def flatten(self, input):
        """Flatten one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``.

        Args:
            input: An iterable of iterables. The name shadows the ``input``
                builtin but is kept so keyword callers keep working.

        Returns:
            A new list with every inner element, in iteration order.
        """
        return [item for inner in input for item in inner]

    @staticmethod
    def init_task(
        task: str,
        ignored_ids: Optional[Union[set, List]] = None,
        model_name: str = "llava",
    ):
        """Load ``task`` and collect its requests, skipping ignored documents.

        Args:
            task: Task name passed to ``get_task_dict``.
            ignored_ids: Optional collection of doc ids; requests whose doc id
                is in it are dropped.
            model_name: Forwarded to ``get_task_dict``. Defaults to ``"llava"``,
                the value that was previously hard-coded.

        Returns:
            ``(DATASET_PATH, DATASET_NAME, split, requests, task_obj, docs)``.
            ``split`` and ``docs`` come from the test split when the task has
            test docs, otherwise from the validation split. ``requests`` are the
            instances built by ``build_all_requests`` minus the ignored ones.
        """
        task_dict = get_task_dict([task], model_name=model_name)
        task_obj = task_dict[task]
        # Some entries come back as a (group, task) pair; only the task is used.
        if isinstance(task_obj, tuple):
            _group, task_obj = task_obj

        dataset_path = task_obj.DATASET_PATH
        dataset_name = task_obj.DATASET_NAME

        # Prefer the test split and fall back to validation.
        if task_obj.has_test_docs():
            docs = task_obj.test_docs()
            split = task_obj.config.test_split
        else:
            docs = task_obj.validation_docs()
            split = task_obj.config.validation_split

        # Report the number of docs, not the docs themselves.
        num_docs = len(docs) if docs is not None else 0
        rank0_print(f"\nTask : {task_obj.config.task}\n - #num : {num_docs}")

        task_obj.build_all_requests()

        # Build the lookup set once so each membership test is O(1).
        ignored = set(ignored_ids) if ignored_ids is not None else set()
        requests = []
        for instance in task_obj.instances:
            # Unpack into distinct names so the `task` argument and the `split`
            # computed above are not overwritten.
            _ctx, _gen_kwargs, _doc_to_visual, doc_id, _task_name, _split = instance.args
            if doc_id in ignored:
                continue
            requests.append(instance)

        return dataset_path, dataset_name, split, requests, task_obj, docs

    @abstractmethod
    def embed_task(self, task: str):
        """Subclass hook that processes ``task``; must be overridden.

        NOTE(review): presumably computes and writes embeddings for the task
        into ``self.output_path`` — see concrete subclasses.
        """
        return
|
|
|