Spaces:
Build error
Build error
| from .dataset import CustomDataset | |
| from torch.utils.data import DataLoader | |
| from src.configs.model_config import ModelConfig | |
| from .transform import data_transform | |
| import os | |
| from torch.utils.data import random_split | |
| num_classes = 3 | |
| config = ModelConfig().get_config() | |
| all_dataset = CustomDataset(data_folder=os.path.join( | |
| "data", 'raw'), transform=data_transform) | |
| # split to train, val, test | |
| total_size = len(all_dataset) | |
| train_size = int(0.8 * total_size) | |
| val_size = int(0.1 * total_size) | |
| test_size = total_size - train_size - val_size | |
| train_dataset, val_dataset, test_dataset = random_split( | |
| all_dataset, [train_size, val_size, test_size]) | |
| train_loader = DataLoader( | |
| train_dataset, batch_size=config.batch_size, shuffle=True) | |
| val_loader = DataLoader( | |
| val_dataset, batch_size=config.batch_size, shuffle=True) | |
| test_loader = DataLoader( | |
| test_dataset, batch_size=config.batch_size, shuffle=True) | |
| def get_train_dataset(batch_size): | |
| return DataLoader( | |
| all_dataset, batch_size=batch_size, shuffle=True) | |