# Model8-GAT / Train2model.py
# Source snapshot: commit 69620d2 ("commit first") by JKL0909.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.multiprocessing import Process, set_start_method
def _use_spawn_start_method():
    """Ask torch.multiprocessing to start child processes with 'spawn'.

    'spawn' is the start method PyTorch recommends when child processes use
    CUDA, as the training workers here may. If a start method has already been
    fixed for this interpreter (for example when a spawned child re-imports
    this module), ``set_start_method`` raises RuntimeError and the existing
    method is kept unchanged.
    """
    try:
        set_start_method('spawn')
    except RuntimeError:
        # Start method already fixed; leave it as is.
        return


_use_spawn_start_method()
# Cấu hình
BATCH_SIZE = 32
EPOCHS = 10
NUM_CLASSES = 2
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_ROOTS = [
'/home/ubuntu/vnet/TaoST/Data10kKaggle1',
'/home/ubuntu/vnet/TaoST/Data10kKaggle2'
]
MODEL_PATHS = [
'/home/ubuntu/vnet/FL/efficientnet_b0_kaggle1.pth',
'/home/ubuntu/vnet/FL/efficientnet_b0_kaggle2.pth'
]
def get_loaders(data_root, batch_size=None, num_workers=4):
    """Build the train and test DataLoaders for one ImageFolder-style dataset.

    Expects ``<data_root>/train`` and ``<data_root>/test``, each laid out with
    one sub-directory per class, which is the layout ``datasets.ImageFolder``
    reads.

    Args:
        data_root: dataset root directory.
        batch_size: samples per batch. ``None`` uses the module-level
            ``BATCH_SIZE``, read at call time.
        num_workers: worker processes per DataLoader.

    Returns:
        ``(train_loader, test_loader)``. Only the train loader shuffles and
        applies random horizontal flips.

    Raises:
        ValueError: if the train and test splits contain different class
            folders.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # NOTE(review): there is no transforms.Normalize step, although the model is
    # fine-tuned from ImageNet-pretrained weights whose documented preprocessing
    # includes mean/std normalization. Confirm this is intentional; any
    # inference code must use exactly the same preprocessing as training.
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    train_set = datasets.ImageFolder(os.path.join(data_root, 'train'), transform=train_transform)
    test_set = datasets.ImageFolder(os.path.join(data_root, 'test'), transform=test_transform)

    # ImageFolder derives label indices from each split's own sorted folder
    # names. If the splits disagree, the same index would mean different classes.
    if train_set.classes != test_set.classes:
        raise ValueError(
            f"Class folders differ between splits in {data_root}: "
            f"train={train_set.classes}, test={test_set.classes}"
        )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over *loader* and return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion returns a batch mean, so re-weight by batch size. This
        # gives a dataset-level mean that a short final batch does not skew.
        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean loss, accuracy)`` of *model* on *loader* without updating it."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            total_loss += criterion(outputs, labels).item() * imgs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    num_samples = len(loader.dataset)
    if num_samples == 0:
        return float('nan'), float('nan')
    return total_loss / num_samples, correct / num_samples


def train_model(data_root, model_path):
    """Fine-tune an ImageNet-pretrained EfficientNet-B0 on one dataset and save it.

    Trains for ``EPOCHS`` epochs with Adam (lr=1e-4) and cross-entropy on
    ``DEVICE``. It prints the mean training loss per epoch and the test loss and
    accuracy once training ends, then writes the model's ``state_dict`` to
    *model_path*.

    Args:
        data_root: directory containing ``train/`` and ``test/`` splits, as
            read by ``get_loaders``.
        model_path: checkpoint file to write. Its parent directory is created
            if it does not exist.

    Raises:
        ValueError: if the number of class folders in the training split
            differs from ``NUM_CLASSES``.
    """
    # Create the checkpoint directory up front so a bad path fails immediately,
    # not after all epochs have run.
    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_loader, test_loader = get_loaders(data_root)

    # The classifier head has NUM_CLASSES outputs. Labels outside that range
    # would otherwise only fail deep inside the loss computation.
    num_found = len(train_loader.dataset.classes)
    if num_found != NUM_CLASSES:
        raise ValueError(
            f"{data_root}: found {num_found} class folders in train/, "
            f"but NUM_CLASSES is {NUM_CLASSES}"
        )

    model = models.efficientnet_b0(weights='IMAGENET1K_V1')
    # Replace the pretrained final linear layer with a NUM_CLASSES-way one.
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(EPOCHS):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        print(f"[{data_root}] Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss:.4f}")

    # Evaluate once on the held-out test split.
    test_loss, test_acc = _evaluate(model, test_loader, criterion, DEVICE)
    print(f"[{data_root}] Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    torch.save(model.state_dict(), model_path)
    print(f"Saved model to {model_path}")
def main():
    """Train one model per ``(DATA_ROOTS[i], MODEL_PATHS[i])`` pair, in parallel.

    Each pair is trained by ``train_model`` in its own process. All processes
    are started first and then joined.

    Raises:
        ValueError: if ``DATA_ROOTS`` and ``MODEL_PATHS`` differ in length.
        RuntimeError: if any training process exits with a non-zero exit code
            (e.g. an uncaught exception inside ``train_model``).
    """
    if len(DATA_ROOTS) != len(MODEL_PATHS):
        raise ValueError(
            f"DATA_ROOTS has {len(DATA_ROOTS)} entries but MODEL_PATHS has "
            f"{len(MODEL_PATHS)}"
        )

    jobs = [
        (data_root, Process(target=train_model, args=(data_root, model_path)))
        for data_root, model_path in zip(DATA_ROOTS, MODEL_PATHS)
    ]
    for _, proc in jobs:
        proc.start()
    for _, proc in jobs:
        proc.join()

    # An exception in a child process only prints a traceback in that child.
    # Surface it here so the script does not exit as if it succeeded.
    failed = [
        f"{data_root} (exit code {proc.exitcode})"
        for data_root, proc in jobs
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError("Training failed for: " + ", ".join(failed))


if __name__ == "__main__":
    # Required with the 'spawn' start method: child processes re-import this
    # module, and the guard stops them from launching training jobs themselves.
    main()