rhfeiyang committed
Commit b65af75 · 1 Parent(s): 36dcf5f
custom_datasets/custom_caption.py CHANGED
@@ -49,7 +49,7 @@ class Caption_set(torch.utils.data.Dataset):
 
 
 class HRS_caption(Caption_set):
-    def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
+    def __init__(self, prompts_path="/data/vision/torralba/selfmanaged/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
         self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
         self.transform = transform
         self.caption_key = "original_prompts"
@@ -65,7 +65,7 @@ class HRS_caption(Caption_set):
         return ret
 
 class Laion_pop(torch.utils.data.Dataset):
-    def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
+    def __init__(self, anno_file="/data/vision/torralba/selfmanaged/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
         self.transform = transform
         self.info = pd.read_csv(anno_file, delimiter=";")
         self.caption_key = "caption"
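
For reference, a minimal usage sketch of the two caption datasets touched above, assuming the repo's module layout (custom_datasets.custom_caption) and the attribute names visible in the hunks; nothing in this sketch is part of the commit:

# Hypothetical usage sketch -- not part of the commit; default paths must exist on the new storage root.
from custom_datasets.custom_caption import HRS_caption, Laion_pop

style_set = HRS_caption(delimiter=",")   # reads the CSV at the new default prompts_path
laion_set = Laion_pop()                  # reads laion_pop500.csv with ';' as delimiter
print(len(style_set.prompts), style_set.caption_key)  # pandas DataFrame rows + "original_prompts"
print(len(laion_set.info), laion_set.caption_key)     # pandas DataFrame rows + "caption"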
custom_datasets/filt/coco/filt.py CHANGED
@@ -109,10 +109,10 @@ def main(args):
     filter.clip_filter = None
     torch.cuda.empty_cache()
 
-    # caption_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/PixArt-alpha/captions"
-    # image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images"
-    # id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict"
-    # filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
+    # caption_folder_path = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/PixArt-alpha/captions"
+    # image_folder_path = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/images"
+    # id_dict_dir = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/images/id_dict"
+    # filt_dir = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/filt_result"
 
     def collate_fn(examples):
         # {"image": image, "id":id}
@@ -142,7 +142,7 @@ def main(args):
 
 
 
-    save_root = f"/vision-nfs/torralba/scratch/jomat/sam_dataset/coco/filt/{args.split}"
+    save_root = f"/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/coco/filt/{args.split}"
     os.makedirs(save_root, exist_ok=True)
 
     if args.mode == "clip_feat":
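
The first hunk only rewrites commented-out path defaults next to collate_fn; for context, a sketch of what a collate_fn matching the documented {"image": image, "id": id} sample format could look like (the actual body in filt.py is not shown in this diff and may differ):

# Sketch only -- illustrative, not the implementation from filt.py.
def collate_fn(examples):
    ret = {}
    ret["image"] = [example["image"] for example in examples]  # keep images as a list, no tensor stacking
    ret["id"] = [example["id"] for example in examples]
    return ret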
custom_datasets/filt/sam_filt.py CHANGED
@@ -31,9 +31,9 @@ def main(args):
     torch.cuda.empty_cache()
 
     caption_folder_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/SAM/subset/captions"
-    image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
-    id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
-    filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
+    image_folder_path = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
+    id_dict_dir = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
+    filt_dir = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/filt_result"
     def collate_fn(examples):
         # {"image": image, "id":id}
         ret = {}
@@ -71,7 +71,7 @@ def main(args):
         id_dict = pickle.load(f)
     ids = list(id_dict.keys())
     dataset = SamDataset(image_folder_path, caption_folder_path, id_file=ids, id_dict_file=id_dict_file)
-    # dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
+    # dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
     clip_logits = None
     clip_logits_file = os.path.join(save_dir, "clip_logits_result.pickle")
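
For context, the renamed id_dict_dir feeds the pickle-loading step shown in the second hunk. A sketch of that step under the new root; the file name all_id_dict.pickle is an assumption borrowed from the commented-out example, not something this commit fixes for this directory:

# Sketch only -- the concrete file name under id_dict_dir is assumed.
import os
import pickle

id_dict_dir = "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
id_dict_file = os.path.join(id_dict_dir, "all_id_dict.pickle")  # assumed name
with open(id_dict_file, "rb") as f:
    id_dict = pickle.load(f)
ids = list(id_dict.keys())  # these ids are what SamDataset receives as id_file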
custom_datasets/mypath.py CHANGED
@@ -5,7 +5,7 @@ class MyPath(object):
     @staticmethod
     def db_root_dir(database=''):
         coco_root = "/data/vision/torralba/datasets/coco_2017"
-        sam_caption_root = "/vision-nfs/torralba/datasets/vision/sam/captions"
+        sam_caption_root = "/data/vision/torralba/selfmanaged/torralba/datasets/vision/sam/captions"
 
         root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         map={
@@ -13,7 +13,7 @@ class MyPath(object):
             "coco_caption_train": f"{coco_root}/annotations/captions_train2017.json",
             "coco_val": f"{coco_root}/val2017/",
             "coco_caption_val": f"{coco_root}/annotations/captions_val2017.json",
-            "sam_images": "/vision-nfs/torralba/datasets/vision/sam/images",
+            "sam_images": "/data/vision/torralba/selfmanaged/torralba/datasets/vision/sam/images",
             "sam_captions": sam_caption_root,
             "sam_whole_filtered_ids_train": "data/filtered_sam/all_remain_ids_train.pickle",
             "sam_whole_filtered_ids_val": "data/filtered_sam/all_remain_ids_val.pickle",
utils/art_filter.py CHANGED
@@ -202,7 +202,7 @@ class Art_filter:
 
 if __name__ == "__main__":
     import pickle
-    with open("/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle","rb") as f:
+    with open("/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle","rb") as f:
         result=pickle.load(f)
     feat = result['clip_features']
     logits =Art_filter().clip_logit_by_feat(feat)