# DMNSP / continual_clip / dynamic_dataset.py
import os
import re
import torch
from tqdm import tqdm
import numpy as np
# from torchvision import datasetsss
from . import datasets, utils
# import clip.clip as clip
import clip
import numpy as np
class DynamicDataset():
    """iCaRL-style exemplar memory for continual CLIP training.

    A fixed budget of ``memory_size`` reference images is shared across all
    tasks seen so far.  When a new task arrives, every stored task's quota is
    shrunk (``reduceExampleSet``) and exemplars for the new task are selected
    by herding (``constructExampleSet``): images are picked greedily so that
    the running mean of their CLIP features tracks the task's mean feature.
    """

    def __init__(self, cfg):
        # key = task identifier; value = list of images (each 3 x 224 x 224,
        # stored as nested Python lists)
        self.ref_database = {}
        # task identifiers currently represented in the exemplar memory
        self.ref_names = []
        # CLIP encoder used to embed candidate images for herding
        self.ref_model, _, self.test_preprocess = clip.load(cfg.model_name, jit=False)
        self.cur_dataset = None   # dataset of the task currently being ingested
        self.memory_size = 5000   # total exemplar budget K over all tasks
        self.batch_id = 0         # running task counter, used as database key

    def update(self, dataset, load):
        """Ingest a new task's dataset into the exemplar memory.

        Args:
            dataset: iterable of (image_tensor, label) samples for the task.
            load: falsy on the first round (store raw samples directly);
                otherwise the current model, which becomes the feature
                extractor used for herding-based exemplar selection.
        """
        self.cur_dataset = dataset
        if not load:  # first round USELESS FOR CICL
            # NOTE(review): this path keys the database by the dataset object
            # itself and never advances batch_id, unlike constructExampleSet —
            # confirm it really is unused for class-incremental runs.
            new_dataset = self.getNewDataset()
            self.ref_database[self.cur_dataset] = new_dataset[:self.memory_size]
            self.ref_names.append(self.cur_dataset)
        else:  # other rounds
            self.ref_model = load
            self.reduceExampleSet()
            self.constructExampleSet()

    def reduceExampleSet(self):
        """Shrink every stored task's exemplar list to K // (t + 1) entries,
        making room for the task about to be added."""
        print("Reducing Example Set")
        K, t = self.memory_size, len(self.ref_names) + 1
        m = K // t
        for dataset in self.ref_names:
            self.ref_database[dataset] = self.ref_database[dataset][:m]

    def constructExampleSet(self):
        """Select exemplars for the current task by herding and store them
        under the current ``batch_id``."""
        print("Constructing Example Set")
        self.ref_names.append(self.batch_id)
        new_dataset = torch.tensor(self.getNewDataset())
        num = new_dataset.shape[0]

        # Embed every candidate image with the reference model.  no_grad()
        # is numerically identical (outputs were detached anyway) but avoids
        # building the autograd graph for a pure inference pass.
        print("[Constructing] Calculating Distance")
        image_feature = []
        with torch.no_grad():
            for ndx in tqdm(range(num)):
                img = new_dataset[ndx].unsqueeze(0).cuda()
                img_feature = self.ref_model(img, None)
                image_feature.append(img_feature.cpu().detach().tolist())
        image_feature = torch.squeeze(torch.tensor(image_feature), dim=1)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
        image_feature = np.array(image_feature)
        image_feature_average = image_feature.mean(axis=0)

        K, t = self.memory_size, len(self.ref_names)
        # BUG FIX: the original computed m = K - K // t, which hands the
        # newest task the bulk of the budget and pushes the total stored
        # exemplars above K once three or more tasks are held.  Every task
        # gets the same K // t quota, matching reduceExampleSet and keeping
        # the total within the memory budget.
        m = K // t
        if not m:  # degenerate case (more tasks than budget slots)
            m = self.memory_size

        update_dataset = []
        cur_embedding_sum = None
        new_dataset_cpu = new_dataset.cpu()
        # BUG FIX: the original deleted chosen rows from image_feature but
        # not from new_dataset, so after the first pick the argmin index no
        # longer addressed the same image.  Track the surviving original
        # indices so features and images stay aligned.
        remaining = np.arange(len(image_feature))
        print("[Constructing] Collecting Examples")
        for k in tqdm(range(min(m, len(image_feature)))):
            if not k:
                # First pick: image closest to the task's mean feature.
                index = np.argmin(
                    np.sum((image_feature_average - image_feature) ** 2, axis=1)
                )
                # .copy() so the accumulator never aliases a row of an array
                # we are about to discard.
                cur_embedding_sum = image_feature[index].copy()
            else:
                # Later picks: image that brings the mean of the selected
                # features closest to the task mean (herding objective).
                index = np.argmin(
                    np.sum((
                        image_feature_average
                        - (1 / (k + 1)) * (image_feature + cur_embedding_sum)
                    ) ** 2, axis=1)
                )
                cur_embedding_sum += image_feature[index]
            update_dataset.append(new_dataset_cpu[remaining[index]].tolist())
            remaining = np.delete(remaining, index)
            image_feature = np.delete(image_feature, index, axis=0)

        self.ref_database[self.batch_id] = update_dataset
        print("finishing current task", self.batch_id)
        self.batch_id = self.batch_id + 1

    def getNewDataset(self):
        """Collect up to 10000 image tensors (as nested lists) from the
        current task's dataset."""
        samples = []
        for count, sample in enumerate(tqdm(self.cur_dataset)):
            if count == 10000:  # cap candidate pool to bound memory/compute
                break
            samples.append(sample[0].tolist())
        return samples

    def get(self):
        """Return every stored exemplar image, concatenated across tasks,
        as one (N, 3, 224, 224) tensor."""
        print("Getting Reference Images")
        out = []
        for images in tqdm(self.ref_database.values()):
            out += images
        return torch.tensor(out)