Raid41 commited on
Commit
5be6be6
·
1 Parent(s): 09800e1

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +4 -4
train.py CHANGED
@@ -27,14 +27,14 @@ def parse_args():
27
  def get_transforms():
28
  return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
29
 
30
- def get_dataloaders(data_path, transforms, batch_size, fine_tuning, mult_number):
31
- train_dataset = TrainDataset(data_path, transforms, mult_number)
32
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
33
-
34
  if fine_tuning:
35
  finetuning_dataset = FineTuningDataset(data_path, transforms)
36
  finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
37
-
38
  return train_dataloader, finetuning_dataloader
39
 
40
  def get_models(device):
 
27
  def get_transforms():
28
  return albu.Compose([albu.RandomCrop(512, 512, always_apply = True), albu.HorizontalFlip(p = 0.5)], p = 1.)
29
 
30
+ def get_dataloaders(data_path, transforms, batch_size, fine_tuning):
31
+ train_dataset = TrainDataset(data_path, transforms)
32
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
33
+
34
  if fine_tuning:
35
  finetuning_dataset = FineTuningDataset(data_path, transforms)
36
  finetuning_dataloader = torch.utils.data.DataLoader(finetuning_dataset, batch_size = batch_size, shuffle = True)
37
+
38
  return train_dataloader, finetuning_dataloader
39
 
40
  def get_models(device):