giorgio-caparvi commited on
Commit
9aabd56
·
1 Parent(s): 35c6408

passing caption through json

Browse files
api/model/src/datasets/dresscode.py CHANGED
@@ -25,6 +25,7 @@ class DressCodeDataset(data.Dataset):
25
  dataroot_path: str,
26
  phase: str,
27
  tokenizer,
 
28
  radius=5,
29
  caption_folder='fine_captions.json',
30
  coarse_caption_folder='coarse_captions.json',
@@ -48,6 +49,7 @@ class DressCodeDataset(data.Dataset):
48
  self.height = size[0]
49
  self.width = size[1]
50
  self.radius = radius
 
51
  self.tokenizer = tokenizer
52
  self.transform = transforms.Compose([
53
  transforms.ToTensor(),
@@ -71,12 +73,19 @@ class DressCodeDataset(data.Dataset):
71
  assert all(x in possible_outputs for x in outputlist)
72
 
73
  # Load Captions
 
74
  with open(self.dataroot / self.caption_folder) as f:
75
  self.captions_dict = json.load(f)
76
  self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3}
77
 
78
  with open(self.dataroot / coarse_caption_folder) as f:
79
  self.captions_dict.update(json.load(f))
 
 
 
 
 
 
80
 
81
  for c in category:
82
  assert c in ['dresses', 'upper_body', 'lower_body']
 
25
  dataroot_path: str,
26
  phase: str,
27
  tokenizer,
28
+ json_from_req,
29
  radius=5,
30
  caption_folder='fine_captions.json',
31
  coarse_caption_folder='coarse_captions.json',
 
49
  self.height = size[0]
50
  self.width = size[1]
51
  self.radius = radius
52
+ self.json_from_req = json_from_req
53
  self.tokenizer = tokenizer
54
  self.transform = transforms.Compose([
55
  transforms.ToTensor(),
 
73
  assert all(x in possible_outputs for x in outputlist)
74
 
75
  # Load Captions
76
+ '''
77
  with open(self.dataroot / self.caption_folder) as f:
78
  self.captions_dict = json.load(f)
79
  self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3}
80
 
81
  with open(self.dataroot / coarse_caption_folder) as f:
82
  self.captions_dict.update(json.load(f))
83
+ '''
84
+
85
+ # Load Captions
86
+ model_data = self.json_from_req.get('MODEL', {}) # Safely get the 'MODEL' key, default to an empty dictionary if it doesn't exist
87
+ # Filter captions based on the length requirement (3 or more items)
88
+ self.captions_dict = {k: v for k, v in model_data.items() if len(v) >= 3}
89
 
90
  for c in category:
91
  assert c in ['dresses', 'upper_body', 'lower_body']