MarcoParola commited on
Commit
dda292a
·
1 Parent(s): 4a4d7e2

fix random image id random generation

Browse files
Files changed (1) hide show
  1. src/utils.py +7 -1
src/utils.py CHANGED
@@ -10,7 +10,13 @@ config = yaml.safe_load(open("./config/config.yaml"))
10
  def get_random_image_id(class_idx, data_dir):
11
  path = os.path.join(data_dir, 'images', str(class_idx))
12
  images = os.listdir(path)
13
- return np.random.randint(0, len(images))
 
 
 
 
 
 
14
 
15
  def load_image_and_saliency(class_idx, data_dir, img_id):
16
  path = os.path.join(data_dir, 'images', str(class_idx))
 
10
  def get_random_image_id(class_idx, data_dir):
11
  path = os.path.join(data_dir, 'images', str(class_idx))
12
  images = os.listdir(path)
13
+ ids = [int(img.split('.')[0]) for img in images if img.endswith('.png')]
14
+ if not ids:
15
+ raise ValueError(f"No images found for class index {class_idx} in {path}")
16
+ # set random seed using time
17
+ np.random.seed(int(time.time()))
18
+ random_id = np.random.randint(0, len(ids))
19
+ return ids[random_id]
20
 
21
  def load_image_and_saliency(class_idx, data_dir, img_id):
22
  path = os.path.join(data_dir, 'images', str(class_idx))