File size: 2,295 Bytes
b0c0df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
import torch.distributed as dist
from lmms_eval.api.registry import ALL_TASKS
from lmms_eval.tasks import (
ConfigurableTask,
get_task_dict,
include_path,
initialize_tasks,
)
def rank0_print(*args):
    """Print ``args`` once per job.

    Without an initialized ``torch.distributed`` process group this is a
    plain ``print(*args)``. With one, only rank 0 prints, and its output is
    prefixed with ``"Rank 0: "``; every other rank prints nothing.
    """
    if not dist.is_initialized():
        print(*args)
        return
    rank = dist.get_rank()
    if rank == 0:
        print(f"Rank {rank}: ", *args)
class BaseEmbedder(ABC):
    """Abstract base for embedders that run over lmms_eval tasks.

    Subclasses implement :meth:`embed_task`. This base class creates the
    output directory, calls ``initialize_tasks()``, and offers
    :meth:`init_task` to load a task and collect its request instances.
    """

    def __init__(self, name: str, output_path: str) -> None:
        """Store configuration and prepare the environment.

        Args:
            name: Identifier for this embedder.
            output_path: Output directory; created (including parents) if it
                does not already exist.
        """
        super().__init__()
        self.name = name
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)
        # NOTE(review): presumably populates the lmms_eval task registry that
        # get_task_dict() resolves names against — confirm.
        initialize_tasks()

    def flatten(self, input):
        """Flatten one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``.

        Args:
            input: An iterable of iterables.

        Returns:
            A new list containing the inner items in their original order.
        """
        return [item for group in input for item in group]

    # A static method to build requests for lmms_eval tasks
    # Pass in task name and return a list of Requests
    @staticmethod
    def init_task(
        task: str,
        ignored_ids: Optional[Union[set, List]] = None,
        *,
        model_name: str = "llava",
    ):
        """Load an lmms_eval task and collect its request instances.

        Args:
            task: Name of the lmms_eval task to load.
            ignored_ids: Optional collection of doc ids; instances whose doc id
                is in it are dropped from the returned requests.
            model_name: Forwarded to ``get_task_dict`` (defaults to "llava").

        Returns:
            Tuple ``(DATASET_PATH, DATASET_NAME, split, requests, task_obj, docs)``.
            ``split`` and ``docs`` come from the test split when the task has
            test docs, otherwise from the validation split. ``requests`` is
            ``task_obj.instances`` minus the ignored doc ids.
        """
        task_dict = get_task_dict([task], model_name=model_name)
        task_obj = task_dict[task]
        # Some entries come back as (group, task) pairs; only the task is used.
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj

        dataset_path = task_obj.DATASET_PATH
        dataset_name = task_obj.DATASET_NAME

        if task_obj.has_test_docs():
            docs = task_obj.test_docs()
            split = task_obj.config.test_split
        else:
            docs = task_obj.validation_docs()
            split = task_obj.config.validation_split
        # Report the document count (not the docs object), without fetching
        # the docs a second time.
        num_docs = len(docs) if docs is not None else 0
        rank0_print(f"\nTask : {task_obj.config.task}\n - #num : {num_docs}")

        task_obj.build_all_requests()
        # Build the lookup set once so each membership test is O(1).
        ignored = set(ignored_ids) if ignored_ids is not None else set()
        # instance.args is laid out as
        # (contexts, gen_kwargs, doc_to_visual, doc_id, task, split).
        # Only doc_id is read, so neither `task` nor the `split` returned
        # below is rebound per instance.
        requests = [
            instance
            for instance in task_obj.instances
            if instance.args[3] not in ignored
        ]
        return dataset_path, dataset_name, split, requests, task_obj, docs

    @abstractmethod
    def embed_task(self, task: str):
        """Process the lmms_eval task named ``task``; subclasses must implement."""
        return
|