| |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import os |
| import sys |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from config import NUM_EPOCHS, LEARNING_RATE, MODEL_DIR, DEVICE |
| from models.king_ai import KingAI |
| from data.dataset import get_dataloaders |
|
|
|
|
def _select_device():
    """Return the torch device to train on, preferring MPS, then CUDA, then CPU."""
    # torch.backends.mps is missing on older PyTorch builds, so look it up
    # defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("✅ 使用 MPS (Apple Silicon GPU) 加速")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("✅ 使用 CUDA (NVIDIA GPU) 加速")
        return torch.device("cuda")
    print("⚠️ 使用 CPU 训练")
    return torch.device("cpu")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    eval mode and gradients are disabled.

    Both returned values are 0.0 when the loader yields no batches, so an
    empty split does not raise ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    with torch.set_grad_enabled(training):
        for images, actions in loader:
            images, actions = images.to(device), actions.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, actions)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            # Predicted class is the index of the largest output along dim 1.
            _, predicted = torch.max(outputs, 1)
            num_samples += actions.size(0)
            num_correct += (predicted == actions).sum().item()

    num_batches = len(loader)
    mean_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = 100 * num_correct / num_samples if num_samples else 0.0
    return mean_loss, accuracy


def train(frames_dir="data/frames/game_01",
          annotation_file="data/annotations/annotations.json"):
    """Train the behavior-cloning model and save checkpoints to ``MODEL_DIR``.

    Args:
        frames_dir: Directory of frames passed to ``get_dataloaders``.
        annotation_file: Annotation JSON path passed to ``get_dataloaders``.

    Side effects:
        Writes ``best_model.pth`` (whenever validation accuracy improves on the
        best seen so far) and ``final_model.pth`` (after the last epoch) into
        ``MODEL_DIR``, creating that directory if needed, and prints progress.
    """
    # NOTE(review): the device is auto-detected here rather than read from
    # config; confirm that is the intended source of truth.
    device = _select_device()

    print("\n加载数据...")
    train_loader, val_loader = get_dataloaders(
        frames_dir=frames_dir,
        annotation_file=annotation_file,
    )

    model = KingAI().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Halve the learning rate every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    # torch.save does not create missing parent directories.
    os.makedirs(MODEL_DIR, exist_ok=True)
    best_path = os.path.join(MODEL_DIR, "best_model.pth")
    final_path = os.path.join(MODEL_DIR, "final_model.pth")

    print(f"\n开始训练 {NUM_EPOCHS} 轮...")
    print("=" * 50)

    best_acc = 0.0

    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        scheduler.step()

        print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] "
              f"Train Loss: {train_loss:.4f} "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} "
              f"Val Acc: {val_acc:.2f}%")

        # Checkpoint only on strict improvement of validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_path)
            print(f" ✅ 保存最佳模型 (准确率: {val_acc:.2f}%)")

    torch.save(model.state_dict(), final_path)
    print(f"\n🎉 训练完成!最佳验证准确率: {best_acc:.2f}%")
    print(f"模型保存在: {MODEL_DIR}")
|
|
|
|
if __name__ == "__main__":
    # Start training only when this file is executed as a script,
    # not when it is imported as a module.
    train()