Spaces:
Runtime error
Runtime error
| # from https://github.com/jmhessel/clipscore/issues/4 | |
| import torch | |
| import random | |
| import scipy | |
| import os | |
| from tqdm import tqdm | |
class Pascal50sDataset(torch.utils.data.Dataset):
    """Pascal-50S pairwise caption-preference dataset.

    Each item pairs two candidate captions (A and B) for one image with:
    - the reference captions taken from that pair's 48 consensus triplets;
    - the pair's category code mapped through ``idx2cat`` ('HC', 'HI', 'HM',
      'MM'; presumably human-correct / human-incorrect / human-machine /
      machine-machine, which should be confirmed against the dataset docs);
    - a label, 0 if A received more votes and 1 if B did, with ties broken
      at random.
    """

    def __init__(self,
                 root: str = "data/Pascal-50s/",
                 media_size: int = 224,
                 voc_path: str = "data/VOC2010/"):
        """Load pair metadata and consensus votes.

        Args:
            root: Directory containing ``pyCIDErConsensus/pair_pascal.mat``
                and ``pyCIDErConsensus/consensus_pascal.mat``.
            media_size: Accepted for interface compatibility; not used by
                this class.
            voc_path: VOC2010 directory; images are resolved under its
                ``JPEGImages`` subdirectory.
        """
        super().__init__()
        self.voc_path = voc_path
        # Seed before read_score so the random tie-breaking is reproducible.
        self.fix_seed()
        # read_data must run first: read_score iterates over len(self).
        self.read_data(root)
        self.read_score(root)
        self.idx2cat = {1: 'HC', 2: 'HI', 3: 'HM', 4: 'MM'}

    @staticmethod
    def loadmat(path):
        """Read a MATLAB ``.mat`` file and return scipy's dict of arrays.

        This is a staticmethod because callers invoke it as
        ``self.loadmat(path)``. As a plain function that call would bind
        ``self`` to ``path`` and fail with TypeError.
        """
        # ``import scipy`` alone is not guaranteed to expose ``scipy.io``,
        # so import the submodule explicitly.
        import scipy.io
        return scipy.io.loadmat(path)

    def fix_seed(self, seed=42):
        """Seed torch's and Python's global RNGs.

        Python's ``random`` module is used by read_score for tie-breaking.
        """
        torch.manual_seed(seed)
        random.seed(seed)

    def read_data(self, root):
        """Load candidate pairs and their categories from ``pair_pascal.mat``.

        Sets ``self.data``, one entry per pair unpacked as
        (image file, caption A, caption B) by ``__getitem__``. Also sets
        ``self.categories``, with codes 1-4 as keys of ``idx2cat``. It then
        checks that those codes agree with the categories re-derived from
        ``new_data``.
        """
        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/pair_pascal.mat"))
        self.data = mat["new_input"][0]
        self.categories = mat["category"][0]
        # Sanity check: re-derive each pair's category from ``new_data``,
        # expect exactly 1000 pairs per category, and require the derived
        # codes (1*hc + 2*hi + 3*hm + 4*mm) to equal ``category``.
        c = torch.Tensor(mat["new_data"])
        hc = (c.sum(dim=-1) == 12).int()
        hi = (c.sum(dim=-1) == 13).int()
        hm = ((c < 6).sum(dim=-1) == 1).int()
        mm = ((c < 6).sum(dim=-1) == 2).int()
        assert 1000 == hc.sum()
        assert 1000 == hi.sum()
        assert 1000 == hm.sum()
        assert 1000 == mm.sum()
        assert (hc + hi + hm + mm).sum() == self.categories.shape[0]
        chk = (torch.Tensor(self.categories) - hc - hi * 2 - hm * 3 - mm * 4)
        assert 0 == chk.abs().sum(), chk

    def read_score(self, root):
        """Derive per-pair labels and references from ``consensus_pascal.mat``.

        Pair ``i`` owns triplets ``i*48 .. i*48+47``. Each triplet adds its
        first caption to the pair's references and one vote to the caption
        it selects. Populates ``self.labels`` (0 or 1 per pair) and
        ``self.references`` (a list of 48 captions per pair). Requires
        ``read_data`` to have run first.

        A voted caption that matches neither candidate in ``self.data``
        triggers a warning; its votes do not count toward either candidate.
        """
        import warnings

        mat = self.loadmat(
            os.path.join(root, "pyCIDErConsensus/consensus_pascal.mat"))
        data = mat["triplets"][0]
        self.labels = []
        self.references = []
        for i in range(len(self)):
            votes = {}
            refs = []
            for j in range(i * 48, (i + 1) * 48):
                a, b, c, d = [x[0][0] for x in data[j]]
                # d selects which caption got this vote: 1 -> b, else -> c.
                key = b[0].strip() if 1 == d else c[0].strip()
                refs.append(a[0].strip())
                votes[key] = votes.get(key, 0) + 1
            assert 2 >= len(votes.keys()), votes
            assert len(votes.keys()) > 0
            cand_a = self.data[i][1][0].strip()
            cand_b = self.data[i][2][0].strip()
            # dict.get never raises, so the mismatch has to be detected
            # explicitly: any voted caption outside {A, B} means the two
            # .mat files disagree for this pair.
            if not votes.keys() <= {cand_a, cand_b}:
                warnings.warn(
                    f"data mismatch at pair {i}: "
                    f"a={cand_a!r}, b={cand_b!r}, votes={votes}")
            vote_a = votes.get(cand_a, 0)
            vote_b = votes.get(cand_b, 0)
            # Ties are broken randomly. With integer votes this yields 0 when
            # vote_a > vote_b and 1 when vote_a < vote_b. When vote_a == vote_b
            # it yields 0 or 1, each with probability 0.5.
            label = 0 if vote_a > vote_b + random.random() - .5 else 1
            self.labels.append(label)
            self.references.append(refs)

    def __len__(self):
        """Return the number of candidate pairs."""
        return len(self.data)

    def get_image_path(self, filename: str):
        """Return the path of ``filename`` inside ``<voc_path>/JPEGImages``."""
        path = os.path.join(self.voc_path, "JPEGImages")
        return os.path.join(path, filename)

    def __getitem__(self, idx: int):
        """Return ``(img_path, a, b, references, category_str, label)``."""
        vid, a, b = [x[0] for x in self.data[idx]]
        label = self.labels[idx]
        img_path = self.get_image_path(vid)
        a = a.strip()
        b = b.strip()
        references = self.references[idx]
        category = self.categories[idx]
        category_str = self.idx2cat[category]
        return img_path, a, b, references, category_str, label
def sanity_check(detail=False, *, root="pascal/",
                 voc_path="pascal/VOCdevkit/VOC2010"):
    """Load Pascal-50S, iterate every sample and verify its image exists.

    Args:
        detail: If True, print every sample's fields while iterating.
        root: Pascal-50S directory passed to ``Pascal50sDataset``.
        voc_path: VOC2010 directory passed to ``Pascal50sDataset``.

    Raises:
        FileNotFoundError: If a sample's image path does not exist.
    """
    dataset = Pascal50sDataset(root=root, voc_path=voc_path)
    # Fail fast on a broken __getitem__ before starting the full pass.
    assert dataset[0] is not None

    def dprint(*args, **kwargs):
        # Print only in detailed mode.
        if detail:
            print(*args, **kwargs)

    for it, sample in enumerate(tqdm(dataset)):
        dprint("=" * 20)
        dprint("sample:", it)
        dprint("=" * 20)
        img_path, a, b, references, category, label = sample
        # Explicit raise instead of assert so the check survives `python -O`.
        if not os.path.exists(img_path):
            raise FileNotFoundError(
                f"missing image for sample {it}: {img_path}")
        dprint("img_path:", img_path)
        dprint("a:", a)
        dprint("b:", b)
        dprint("references:", references)
        dprint("category:", category)
        dprint("label:", label)
if __name__ == "__main__":
    # Script entry point: a quiet full pass over the dataset (no per-sample output).
    sanity_check(False)