import os import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms, models from torch.utils.data import DataLoader from torch.multiprocessing import Process, set_start_method try: set_start_method('spawn') except RuntimeError: pass # Cấu hình BATCH_SIZE = 32 EPOCHS = 10 NUM_CLASSES = 2 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' DATA_ROOTS = [ '/home/ubuntu/vnet/TaoST/Data10kKaggle1', '/home/ubuntu/vnet/TaoST/Data10kKaggle2' ] MODEL_PATHS = [ '/home/ubuntu/vnet/FL/efficientnet_b0_kaggle1.pth', '/home/ubuntu/vnet/FL/efficientnet_b0_kaggle2.pth' ] def get_loaders(data_root): train_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) test_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) train_set = datasets.ImageFolder(os.path.join(data_root, 'train'), transform=train_transform) test_set = datasets.ImageFolder(os.path.join(data_root, 'test'), transform=test_transform) train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4) test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) return train_loader, test_loader def train_model(data_root, model_path): train_loader, test_loader = get_loaders(data_root) model = models.efficientnet_b0(weights='IMAGENET1K_V1') model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES) model = model.to(DEVICE) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=1e-4) for epoch in range(EPOCHS): model.train() running_loss = 0.0 for imgs, labels in train_loader: imgs, labels = imgs.to(DEVICE), labels.to(DEVICE) optimizer.zero_grad() outputs = model(imgs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() * imgs.size(0) print(f"[{data_root}] Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader.dataset):.4f}") torch.save(model.state_dict(), model_path) print(f"Saved model to {model_path}") def main(): p1 = Process(target=train_model, args=(DATA_ROOTS[0], MODEL_PATHS[0])) p2 = Process(target=train_model, args=(DATA_ROOTS[1], MODEL_PATHS[1])) p1.start() p2.start() p1.join() p2.join() if __name__ == "__main__": main()