Raid41 commited on
Commit
1d7df32
·
1 Parent(s): fe15f7e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -0
train.py CHANGED
@@ -29,15 +29,18 @@ def get_transforms():
29
 
30
  def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
31
  train_dataset = TrainDataset(data_path, transforms, mults_amount=mult_number)
 
32
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
33
 
34
  if fine_tuning:
35
  finetuning_dataset = FineTuningDataset(data_path, transforms, mult_amount=mult_number)
 
36
  finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size=batch_size, shuffle=True)
37
  return train_dataloader, finetuning_dataloader
38
 
39
  return train_dataloader, None
40
 
 
41
  def get_models(device):
42
  generator = Generator()
43
  extractor = get_seresnext_extractor()
 
29
 
30
  def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
31
  train_dataset = TrainDataset(data_path, transforms, mults_amount=mult_number)
32
+ print("Train Dataset Length:", len(train_dataset)) # Debug print
33
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
34
 
35
  if fine_tuning:
36
  finetuning_dataset = FineTuningDataset(data_path, transforms, mult_amount=mult_number)
37
+ print("FineTuning Dataset Length:", len(finetuning_dataset)) # Debug print
38
  finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size=batch_size, shuffle=True)
39
  return train_dataloader, finetuning_dataloader
40
 
41
  return train_dataloader, None
42
 
43
+
44
  def get_models(device):
45
  generator = Generator()
46
  extractor = get_seresnext_extractor()