Spaces:
Runtime error
Runtime error
| import h5py | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch | |
| from knowledge import TextDB | |
class ImageCropsIdx:
    """In-memory lookup of per-image knowledge retrieval results.

    The HDF5 index file is read once at construction time; afterwards the
    object is indexed by image id.

    Attributes:
        topk: crop type name (``"whole"``, ``"five"``, ``"nine"``) -> number of
            retrieved entries kept. Crop types whose top-k is not positive
            are dropped.
        knowledge_idx: image id -> crop type -> dict with the keys
            ``"index"``, ``"score"`` and ``"query"``.
        fdim: value of the file's ``fdim`` attribute.
        file_hash: value of the file's ``file_hash`` attribute.
    """

    def __init__(self, knowledge_idx, topk_w, topk_f, topk_n):
        """Read the index file and keep only crop types with a positive top-k.

        Args:
            knowledge_idx: path (or whatever ``h5py.File`` accepts) of the
                HDF5 index file.
            topk_w: entries to keep for the ``"whole"`` crop type.
            topk_f: entries to keep for the ``"five"`` crop type.
            topk_n: entries to keep for the ``"nine"`` crop type.
        """
        requested = {"whole": topk_w, "five": topk_f, "nine": topk_n}
        self.topk = {k: v for k, v in requested.items() if v > 0}
        self.knowledge_idx, self.fdim, self.file_hash = self.load(knowledge_idx, self.topk)

    def load(self, knowledge_idx, topk):
        """Load the index file and regroup its contents by image id.

        The file is expected to contain top-level groups named ``"0"`` ..
        ``str(len(f) - 1)``. Each group holds an ``image_ids`` dataset and, for
        every crop type in ``topk``, ``index``/``score``/``query`` datasets
        whose first axis is walked in step with ``image_ids``.

        Args:
            knowledge_idx: path of the HDF5 index file.
            topk: crop type -> number of entries to keep along the last axis
                of ``index`` and ``score``.

        Returns:
            A ``(per_image, fdim, file_hash)`` tuple. ``per_image`` maps each
            image id to ``{crop_type: {"index", "score", "query"}}``. If an
            image id occurs more than once, the last occurrence wins.
        """
        per_image = {}
        with h5py.File(knowledge_idx, "r") as f:
            fdim = f.attrs["fdim"]
            file_hash = f.attrs["file_hash"]

            for i in tqdm(range(len(f)), desc="Load sentence idx", dynamic_ncols=True, mininterval=1.0):
                image_ids = f[f"{i}/image_ids"][:]
                # Slice on the h5py dataset so that only the first `v`
                # entries along the last axis are read from disk.
                group = {
                    k: {
                        "index": f[f"{i}/{k}/index"][:, :, :v],
                        "score": f[f"{i}/{k}/score"][:, :, :v],
                        "query": f[f"{i}/{k}/query"][:],
                    }
                    for k, v in topk.items()
                }
                # Split the group's arrays row by row, one row per image id.
                for row, image_id in enumerate(image_ids):
                    per_image[image_id] = {
                        k: {field: values[row] for field, values in fields.items()}
                        for k, fields in group.items()
                    }
        return per_image, fdim, file_hash

    def __getitem__(self, image_id):
        """Return the per-crop-type entry stored for ``image_id``.

        Raises:
            KeyError: if ``image_id`` is not present in the loaded index.
        """
        return self.knowledge_idx[image_id]
class KnowAugImageCrops:
    """Turn an image id into tensors of retrieved knowledge, per crop type.

    For every crop type enabled in the index, the retrieved entry indices are
    looked up in the knowledge database and returned together with the stored
    queries, scores and a crop-position id for each retrieved entry.
    """

    def __init__(self, knowledge_db: "TextDB", knowledge_idx: "ImageCropsIdx", return_txt=False):
        """Bind a knowledge database to a matching crop index.

        Args:
            knowledge_db: database indexed with an array of entry indices;
                must return an ``(embeddings, texts)`` pair and expose
                ``file_hash``.
            knowledge_idx: per-image index exposing ``file_hash``, ``topk``
                and ``fdim`` and indexable by image id.
            return_txt: when true, ``__call__`` also returns the texts.

        Raises:
            AssertionError: if the two objects report different
                ``file_hash`` values.
        """
        self.knowledge_db = knowledge_db
        self.knowledge_idx = knowledge_idx
        # Explicit raise rather than a bare `assert`, so the consistency check
        # still runs under `python -O`. AssertionError is kept so existing
        # callers see the same exception type.
        if not (knowledge_db.file_hash == knowledge_idx.file_hash):
            raise AssertionError(
                "file_hash mismatch between knowledge_db "
                f"({knowledge_db.file_hash!r}) and knowledge_idx "
                f"({knowledge_idx.file_hash!r})"
            )

        # Number of crops that each crop type is made of.
        self.ncrop = {"whole": 1, "five": 5, "nine": 9}
        self.topk = knowledge_idx.topk
        self.fdim = knowledge_idx.fdim

        self.return_txt = return_txt

    def __call__(self, image_id):
        """Collect the retrieved knowledge for ``image_id``.

        Returns:
            Dict keyed by crop type. Each value holds ``"embed"``, ``"query"``
            and ``"score"`` float tensors and a ``"pos"`` long tensor, plus
            ``"text"`` when ``return_txt`` is set. ``"pos"`` repeats each crop
            number ``topk`` times: ``0, 0, ..., 1, 1, ...``.
        """
        ret = {}
        for k, topk in self.topk.items():
            crop = self.knowledge_idx[image_id][k]

            ki = crop["index"].flatten()
            ke, kt = self.knowledge_db[ki]
            # One position id per retrieved entry, in the same crop-major
            # order as the flattened index and score arrays.
            kp = np.repeat(np.arange(self.ncrop[k]), topk)

            ret[k] = {
                "embed": torch.FloatTensor(ke),
                "query": torch.FloatTensor(crop["query"]),
                "pos": torch.LongTensor(kp),
                "score": torch.FloatTensor(crop["score"].flatten()),
            }
            if self.return_txt:
                ret[k]["text"] = kt

        return ret
class KnowAugImageCropsCombined:
    """Run the object, attribute and action augmenters as a single callable."""

    def __init__(
        self,
        knwl_aug_obj: KnowAugImageCrops,
        knwl_aug_attr: KnowAugImageCrops,
        knwl_aug_act: KnowAugImageCrops
    ):
        """Store the three augmenters; ``fdim`` is taken from the object one."""
        self.knwl_aug_obj = knwl_aug_obj
        self.knwl_aug_attr = knwl_aug_attr
        self.knwl_aug_act = knwl_aug_act
        self.fdim = knwl_aug_obj.fdim

    def __call__(self, image_id):
        """Query all three augmenters for ``image_id`` and merge the results.

        Returns:
            Dict keyed by the crop types of the object augmenter's result.
            Each value is ``{"obj": ..., "attr": ..., "act": ...}`` holding
            the three augmenters' outputs for that crop type.
        """
        from_obj = self.knwl_aug_obj(image_id)
        from_attr = self.knwl_aug_attr(image_id)
        from_act = self.knwl_aug_act(image_id)

        return {
            crop: {
                "obj": from_obj[crop],
                "attr": from_attr[crop],
                "act": from_act[crop],
            }
            for crop in from_obj.keys()
        }