|
|
import os |
|
|
import sys |
|
|
from typing import Dict, List, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from lmms_eval.tasks import initialize_tasks |
|
|
|
|
|
from .BaseShrinker import BaseShrinker |
|
|
from .sampling_methods import AVAILABEL_METHODS |
|
|
|
|
|
sys.path.append("..") |
|
|
from embedder import BaseEmbedder |
|
|
from shrinker import sampling_methods as sampling_methods_module |
|
|
|
|
|
|
|
|
class Embed_Shrinker(BaseShrinker):
    """Shrinker that picks a subset of a task's dataset from cached embeddings.

    The embeddings are read from ``<embed_cache_path>/<task>_embed.npy`` and
    handed to one of the sampling methods registered in ``AVAILABEL_METHODS``;
    the indices it returns are used to select rows from the task's dataset.
    """

    def __init__(
        self,
        task: str,
        num_items: Union[int, float],
        name: str,
        embed_cache_path: str,
        sampling_methods: str,
        push_to_hub: bool,
    ) -> None:
        """Set up the task and resolve the sampling method.

        Args:
            task: Task name; forwarded to the base class and to
                ``BaseEmbedder.init_task``.
            num_items: Number of items to keep. Values below ``1.0`` are
                interpreted by ``shrink`` as a fraction of the embedded items.
            name: Forwarded to the base class.
            embed_cache_path: Directory holding ``<task>_embed.npy``.
            sampling_methods: Key into ``AVAILABEL_METHODS`` naming the
                sampling class to use.
            push_to_hub: Forwarded to the base class; ``shrink`` reads
                ``self.push_to_hub`` (presumably stored by the base class).

        Raises:
            AssertionError: If ``sampling_methods`` is not a known key.
        """
        super().__init__(task, num_items, name, push_to_hub)
        self.embed_cache_path = embed_cache_path
        initialize_tasks()
        # Only part of what init_task returns is kept on the instance.
        (
            self.DATASET_PATH,
            self.DATASET_NAME,
            self.split,
            _,
            self.task_obj,
            _docs,
        ) = BaseEmbedder.init_task(task)
        assert sampling_methods in AVAILABEL_METHODS, f"Not available sampling methods, Choose from {AVAILABEL_METHODS.keys()}"
        # Stored as the sampling *class*; it is instantiated in shrink().
        self.sampling_methods = getattr(sampling_methods_module, AVAILABEL_METHODS[sampling_methods])

    def shrink(self):
        """Select anchor points from the cached embeddings and subset the dataset.

        Loads the task embedding, drops a singleton middle axis if the array is
        3-D, instantiates the sampling method on it, and selects
        ``self.num_items`` rows from ``self.task_obj.dataset[self.split]``.
        When ``self.push_to_hub`` is truthy the subset is pushed to the
        ``lmms-lab/LMMs-Eval-Lite`` repo under config ``self.task`` and split
        ``"lite"``.

        ``self.num_items`` is resolved to an integer count in place. The
        sampler instance is kept local so ``self.sampling_methods`` stays the
        class and this method can be called more than once.

        Returns:
            None.
        """
        embed_path = os.path.join(self.embed_cache_path, f"{self.task}_embed.npy")
        # Passing the path lets NumPy open and close the file itself, instead
        # of leaking a handle from a bare open().
        task_embedding = np.load(embed_path)

        # Collapse (N, 1, D) to (N, D). A 3-D array whose axis 1 is not of
        # size 1 is left unchanged.
        if task_embedding.ndim == 3 and task_embedding.shape[1] == 1:
            task_embedding = task_embedding.squeeze(axis=1)

        sampler = self.sampling_methods(X=task_embedding)

        if self.num_items < 1.0:
            # Fractional request: scale by the number of embedded items.
            self.num_items = int(task_embedding.shape[0] * self.num_items)
        else:
            self.num_items = int(self.num_items)

        anchor_points = sampler.select_batch(N=self.num_items)
        dataset = self.task_obj.dataset[self.split]
        tiny_dataset = dataset.select(anchor_points)

        if self.push_to_hub:
            tiny_dataset.push_to_hub(repo_id="lmms-lab/LMMs-Eval-Lite", config_name=self.task, split="lite")
        return
|
|
|