Commit ·
dd9e164
0
Parent(s):
Duplicate from clarenceleo/HonorNet_v1
Browse files- .gitattributes +35 -0
- LICENSE +21 -0
- README.md +121 -0
- config.py +47 -0
- inference/action_mapper.py +169 -0
- inference/game_controller.py +33 -0
- inference/run_ai.py +106 -0
- models/best_model.pth +3 -0
- models/final_model.pth +3 -0
- models/king_ai.py +119 -0
- models/train_bc.py +109 -0
- requirements.txt +7 -0
- tests/test_mapper.py +163 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 [Tianyi Li]
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HonorNet 🎮
|
| 2 |
+
|
| 3 |
+
> 一个基于行为克隆的王者荣耀AI,从零开始,在Mac上训练,在Android模拟器中运行。
|
| 4 |
+
|
| 5 |
+
[](LICENSE)
|
| 6 |
+
[](https://www.python.org/)
|
| 7 |
+
[](https://pytorch.org/)
|
| 8 |
+
|
| 9 |
+
## ✨ 项目简介
|
| 10 |
+
|
| 11 |
+
**HonorNet** 是一个完全开源的王者荣耀AI项目。它的特点:
|
| 12 |
+
|
| 13 |
+
- 🍎 **在Mac上训练**:利用Apple Silicon的MPS加速
|
| 14 |
+
- 📱 **控制Android模拟器**:通过ADB截图和发送触摸事件
|
| 15 |
+
- 🧠 **行为克隆**:学习人类玩家的操作数据
|
| 16 |
+
- 🎯 **16个动作空间**:移动8方向 + 技能 + 攻击 + 战术
|
| 17 |
+
- 🔓 **MIT协议**:完全开源,任意使用
|
| 18 |
+
|
| 19 |
+
这不是一个“调用API”的demo,而是一个**从数据采集到模型训练再到部署**的完整工程。
|
| 20 |
+
|
| 21 |
+
## 🎯 项目状态
|
| 22 |
+
|
| 23 |
+
| 阶段 | 状态 |
|
| 24 |
+
|------|------|
|
| 25 |
+
| 数据采集 | ✅ 完成(886帧标注) |
|
| 26 |
+
| 行为克隆训练 | ✅ 完成(验证准确率54.5%) |
|
| 27 |
+
| 模拟器部署 | ✅ 完成 |
|
| 28 |
+
| 强化学习微调 | 🚧 进行中 |
|
| 29 |
+
|
| 30 |
+
## 🏗️ 项目结构
|
| 31 |
+
|
| 32 |
+
```
|
| 33 |
+
HonorNet/
|
| 34 |
+
├── data/ # 数据处理脚本
|
| 35 |
+
├── models/ # 模型定义 + 训练
|
| 36 |
+
├── inference/ # 模拟器控制 + AI运行
|
| 37 |
+
├── config.py # 配置文件
|
| 38 |
+
└── requirements.txt # 依赖
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## 🚀 快速开始
|
| 42 |
+
|
| 43 |
+
### 1. 环境配置
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# 克隆仓库
|
| 47 |
+
git lfs clone https://huggingface.co/clarenceleo/HonorNet_v1
|
| 48 |
+
cd HonorNet_v1
|
| 49 |
+
|
| 50 |
+
# 安装依赖
|
| 51 |
+
pip install -r requirements.txt
|
| 52 |
+
|
| 53 |
+
# 安装ADB(macOS)
|
| 54 |
+
brew install android-platform-tools
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### 2. 准备数据
|
| 58 |
+
|
| 59 |
+
将王者荣耀1v1录屏放入 `data/raw_videos/`,然后:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
# 抽帧
|
| 63 |
+
python data/extract_frames.py
|
| 64 |
+
|
| 65 |
+
# 预处理
|
| 66 |
+
python data/preprocess.py
|
| 67 |
+
|
| 68 |
+
# 标注动作(可选,我们提供了标注工具)
|
| 69 |
+
python data/annotate.py
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 3. 训练模型
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
python models/train_bc.py
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### 4. 让AI打游戏
|
| 79 |
+
|
| 80 |
+
1. 启动Android模拟器,打开王者荣耀1v1模式
|
| 81 |
+
2. 运行AI:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python inference/run_ai.py
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## 📊 训练结果
|
| 88 |
+
|
| 89 |
+
使用886帧标注数据训练50轮:
|
| 90 |
+
|
| 91 |
+
| 指标 | 数值 |
|
| 92 |
+
|------|------|
|
| 93 |
+
| 最佳验证准确率 | **54.49%** |
|
| 94 |
+
| 随机基线 | 6.25% |
|
| 95 |
+
| 训练准确率 | 95.48% |
|
| 96 |
+
|
| 97 |
+
**54.5%的准确率意味着AI学会了**:
|
| 98 |
+
- 根据画面判断移动方向
|
| 99 |
+
- 何时普攻、放技能
|
| 100 |
+
- 何时回城、升级
|
| 101 |
+
|
| 102 |
+
## 🛠️ 技术栈
|
| 103 |
+
|
| 104 |
+
- **PyTorch**:深度学习框架
|
| 105 |
+
- **ADB**:Android调试桥,控制模拟器
|
| 106 |
+
- **OpenCV**:图像处理
|
| 107 |
+
- **Android Studio AVD**:模拟器运行环境
|
| 108 |
+
|
| 109 |
+
## 📈 后续计划
|
| 110 |
+
|
| 111 |
+
- [ ] 强化学习微调(PPO)
|
| 112 |
+
- [ ] 支持更多英雄
|
| 113 |
+
- [ ] 5v5多智能体
|
| 114 |
+
- [ ] 实时学习(边打边学)
|
| 115 |
+
|
| 116 |
+
## 🤝 贡献
|
| 117 |
+
|
| 118 |
+
模型开发与训练数据清洗标注:李天祎(1637321445@qq.com)
|
| 119 |
+
训练数据录制与提供:姜懿原
|
| 120 |
+
|
| 121 |
+
欢迎提交Issue和Pull Request!
|
config.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# 路径配置
|
| 5 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 6 |
+
DATA_DIR = os.path.join(BASE_DIR, "data")
|
| 7 |
+
RAW_VIDEO_DIR = os.path.join(DATA_DIR, "raw_videos")
|
| 8 |
+
FRAMES_DIR = os.path.join(DATA_DIR, "frames")
|
| 9 |
+
ANNOTATIONS_DIR = os.path.join(DATA_DIR, "annotations")
|
| 10 |
+
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
|
| 11 |
+
MODEL_DIR = os.path.join(BASE_DIR, "models")
|
| 12 |
+
|
| 13 |
+
# 创建目录
|
| 14 |
+
os.makedirs(RAW_VIDEO_DIR, exist_ok=True)
|
| 15 |
+
os.makedirs(FRAMES_DIR, exist_ok=True)
|
| 16 |
+
os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
|
| 17 |
+
os.makedirs(PROCESSED_DIR, exist_ok=True)
|
| 18 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
# 图像处理参数
|
| 21 |
+
IMG_HEIGHT = 84
|
| 22 |
+
IMG_WIDTH = 84
|
| 23 |
+
IMG_CHANNELS = 3
|
| 24 |
+
CROP_TOP_RATIO = 0.08 # 裁剪顶部8%(去掉状态栏)
|
| 25 |
+
CROP_BOTTOM_RATIO = 0.05 # 裁剪底部5%(去掉按钮栏)
|
| 26 |
+
|
| 27 |
+
# 视频抽帧参数
|
| 28 |
+
EXTRACT_FPS = 5 # 每秒抽5帧
|
| 29 |
+
|
| 30 |
+
# 动作空间(16个动作)
|
| 31 |
+
ACTIONS = [
|
| 32 |
+
'move_up', 'move_down', 'move_left', 'move_right',
|
| 33 |
+
'move_upleft', 'move_upright', 'move_downleft', 'move_downright',
|
| 34 |
+
'attack', 'skill_1', 'skill_2', 'skill_3',
|
| 35 |
+
'recall', 'heal', 'summoner', 'upgrade'
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
NUM_ACTIONS = len(ACTIONS)
|
| 39 |
+
|
| 40 |
+
# 训练参数
|
| 41 |
+
BATCH_SIZE = 64
|
| 42 |
+
LEARNING_RATE = 0.0001
|
| 43 |
+
NUM_EPOCHS = 50
|
| 44 |
+
TRAIN_SPLIT = 0.8
|
| 45 |
+
|
| 46 |
+
# 设备
|
| 47 |
+
DEVICE = None # 运行时自动检测
|
inference/action_mapper.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference/action_mapper.py
|
| 2 |
+
import subprocess
|
| 3 |
+
import time
|
| 4 |
+
import io
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
class GameController:
|
| 8 |
+
"""ADB 游戏控制器"""
|
| 9 |
+
|
| 10 |
+
def __init__(self, device_id="emulator-5554"):
|
| 11 |
+
self.device_id = device_id
|
| 12 |
+
|
| 13 |
+
def tap(self, x, y):
|
| 14 |
+
"""点击"""
|
| 15 |
+
cmd = f"adb -s {self.device_id} shell input tap {x} {y}"
|
| 16 |
+
subprocess.run(cmd, shell=True)
|
| 17 |
+
|
| 18 |
+
def swipe(self, x1, y1, x2, y2, duration=50):
|
| 19 |
+
"""滑动"""
|
| 20 |
+
cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration}"
|
| 21 |
+
subprocess.run(cmd, shell=True)
|
| 22 |
+
|
| 23 |
+
def swipe_continuous(self, x1, y1, x2, y2, duration_ms=50):
|
| 24 |
+
"""连续滑动(用于保持移动)"""
|
| 25 |
+
cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration_ms}"
|
| 26 |
+
subprocess.run(cmd, shell=True)
|
| 27 |
+
|
| 28 |
+
def screenshot(self):
|
| 29 |
+
"""截图"""
|
| 30 |
+
cmd = f"adb -s {self.device_id} exec-out screencap -p"
|
| 31 |
+
output = subprocess.check_output(cmd, shell=True)
|
| 32 |
+
return Image.open(io.BytesIO(output))
|
| 33 |
+
|
| 34 |
+
def get_screen_size(self):
|
| 35 |
+
"""获取屏幕分辨率"""
|
| 36 |
+
cmd = f"adb -s {self.device_id} shell wm size"
|
| 37 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
| 38 |
+
size_str = output.split(":")[1].strip()
|
| 39 |
+
w, h = map(int, size_str.split("x"))
|
| 40 |
+
return w, h
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ActionMapper:
|
| 44 |
+
"""
|
| 45 |
+
动作映射器 - 使用持续滑动保持移动
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, controller):
|
| 49 |
+
self.ctrl = controller
|
| 50 |
+
|
| 51 |
+
# 按钮坐标
|
| 52 |
+
self.buttons = {
|
| 53 |
+
"joystick_center": (448, 861),
|
| 54 |
+
"attack": (1936, 925),
|
| 55 |
+
"skill_1": (1723, 750),
|
| 56 |
+
"skill_2": (1927, 635),
|
| 57 |
+
"skill_3": (1443, 969),
|
| 58 |
+
"recall": (1150, 979),
|
| 59 |
+
"heal": (1283, 979),
|
| 60 |
+
"summoner": (1443, 969),
|
| 61 |
+
"upgrade": (1513, 833),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# 移动目标坐标(摇杆应该滑到的位置)
|
| 65 |
+
self.move_targets = {
|
| 66 |
+
"up": (448, 741),
|
| 67 |
+
"down": (448, 981),
|
| 68 |
+
"left": (328, 861),
|
| 69 |
+
"right": (568, 861),
|
| 70 |
+
"upleft": (363, 776),
|
| 71 |
+
"upright": (533, 776),
|
| 72 |
+
"downleft": (363, 946),
|
| 73 |
+
"downright": (533, 946),
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
# 当前状态
|
| 77 |
+
self.current_direction = None
|
| 78 |
+
|
| 79 |
+
# 动作映射
|
| 80 |
+
self.action_to_button = {
|
| 81 |
+
"attack": "attack",
|
| 82 |
+
"skill_1": "skill_1",
|
| 83 |
+
"skill_2": "skill_2",
|
| 84 |
+
"skill_3": "skill_3",
|
| 85 |
+
"summoner": "summoner",
|
| 86 |
+
"recall": "recall",
|
| 87 |
+
"heal": "heal",
|
| 88 |
+
"upgrade": "upgrade",
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
def execute(self, action_name):
|
| 92 |
+
"""
|
| 93 |
+
执行动作(每帧调用)
|
| 94 |
+
"""
|
| 95 |
+
# 移动动作
|
| 96 |
+
if action_name.startswith("move_"):
|
| 97 |
+
direction = action_name.replace("move_", "")
|
| 98 |
+
if direction in self.move_targets:
|
| 99 |
+
self._do_move(direction)
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
# 停止移动
|
| 103 |
+
if action_name == "move_stop":
|
| 104 |
+
self._stop_move()
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
# 点击动作
|
| 108 |
+
if action_name in self.action_to_button:
|
| 109 |
+
button = self.action_to_button[action_name]
|
| 110 |
+
if button in self.buttons:
|
| 111 |
+
x, y = self.buttons[button]
|
| 112 |
+
self.ctrl.tap(x, y)
|
| 113 |
+
|
| 114 |
+
def _do_move(self, direction):
|
| 115 |
+
"""
|
| 116 |
+
执行移动:每帧都滑动到目标位置
|
| 117 |
+
这样才能保持英雄持续移动
|
| 118 |
+
"""
|
| 119 |
+
cx, cy = self.buttons["joystick_center"]
|
| 120 |
+
tx, ty = self.move_targets[direction]
|
| 121 |
+
|
| 122 |
+
# 每帧都执行滑动,保持摇杆位置
|
| 123 |
+
self.ctrl.swipe(cx, cy, tx, ty, duration=30)
|
| 124 |
+
self.current_direction = direction
|
| 125 |
+
|
| 126 |
+
def _stop_move(self):
|
| 127 |
+
"""停止移动:摇杆回中心"""
|
| 128 |
+
cx, cy = self.buttons["joystick_center"]
|
| 129 |
+
self.ctrl.swipe(cx, cy, cx, cy, duration=30)
|
| 130 |
+
self.current_direction = None
|
| 131 |
+
|
| 132 |
+
def attack(self):
|
| 133 |
+
"""普攻"""
|
| 134 |
+
x, y = self.buttons["attack"]
|
| 135 |
+
self.ctrl.tap(x, y)
|
| 136 |
+
|
| 137 |
+
def skill_1(self):
|
| 138 |
+
"""技能1"""
|
| 139 |
+
x, y = self.buttons["skill_1"]
|
| 140 |
+
self.ctrl.tap(x, y)
|
| 141 |
+
|
| 142 |
+
def skill_2(self):
|
| 143 |
+
"""技能2"""
|
| 144 |
+
x, y = self.buttons["skill_2"]
|
| 145 |
+
self.ctrl.tap(x, y)
|
| 146 |
+
|
| 147 |
+
def skill_3(self):
|
| 148 |
+
"""技能3"""
|
| 149 |
+
x, y = self.buttons["skill_3"]
|
| 150 |
+
self.ctrl.tap(x, y)
|
| 151 |
+
|
| 152 |
+
def recall(self):
|
| 153 |
+
"""回城"""
|
| 154 |
+
x, y = self.buttons["recall"]
|
| 155 |
+
self.ctrl.tap(x, y)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
ctrl = GameController()
|
| 160 |
+
mapper = ActionMapper(ctrl)
|
| 161 |
+
|
| 162 |
+
print("测试移动...")
|
| 163 |
+
print("向上移动3秒")
|
| 164 |
+
for i in range(30):
|
| 165 |
+
mapper.execute("move_up")
|
| 166 |
+
time.sleep(0.1)
|
| 167 |
+
|
| 168 |
+
print("停止")
|
| 169 |
+
mapper.execute("move_stop")
|
inference/game_controller.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# game_controller.py
|
| 2 |
+
import subprocess
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import io
|
| 5 |
+
|
| 6 |
+
class GameController:
|
| 7 |
+
def __init__(self, device_id="emulator-5554"):
|
| 8 |
+
self.device_id = device_id
|
| 9 |
+
|
| 10 |
+
def screenshot(self):
|
| 11 |
+
"""截图并返回PIL Image对象"""
|
| 12 |
+
cmd = f"adb -s {self.device_id} exec-out screencap -p"
|
| 13 |
+
output = subprocess.check_output(cmd, shell=True)
|
| 14 |
+
img = Image.open(io.BytesIO(output))
|
| 15 |
+
return img
|
| 16 |
+
|
| 17 |
+
def tap(self, x, y):
|
| 18 |
+
"""点击指定坐标"""
|
| 19 |
+
cmd = f"adb -s {self.device_id} shell input tap {x} {y}"
|
| 20 |
+
subprocess.run(cmd, shell=True)
|
| 21 |
+
|
| 22 |
+
def swipe(self, x1, y1, x2, y2, duration=50):
|
| 23 |
+
"""滑动操作"""
|
| 24 |
+
cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration}"
|
| 25 |
+
subprocess.run(cmd, shell=True)
|
| 26 |
+
|
| 27 |
+
def get_screen_size(self):
|
| 28 |
+
"""获取屏幕分辨率"""
|
| 29 |
+
cmd = f"adb -s {self.device_id} shell wm size"
|
| 30 |
+
output = subprocess.check_output(cmd, shell=True).decode()
|
| 31 |
+
size_str = output.split(":")[1].strip()
|
| 32 |
+
w, h = map(int, size_str.split("x"))
|
| 33 |
+
return w, h
|
inference/run_ai.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference/run_ai.py
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import time
|
| 8 |
+
import numpy as np
|
| 9 |
+
from inference.action_mapper import GameController, ActionMapper
|
| 10 |
+
from models.king_ai import KingAI
|
| 11 |
+
from data.preprocess import ImageProcessor
|
| 12 |
+
from config import NUM_ACTIONS, ACTIONS
|
| 13 |
+
|
| 14 |
+
class AIPlayer:
|
| 15 |
+
def __init__(self, model_path):
|
| 16 |
+
# 设备
|
| 17 |
+
if torch.backends.mps.is_available():
|
| 18 |
+
self.device = torch.device("mps")
|
| 19 |
+
elif torch.cuda.is_available():
|
| 20 |
+
self.device = torch.device("cuda")
|
| 21 |
+
else:
|
| 22 |
+
self.device = torch.device("cpu")
|
| 23 |
+
|
| 24 |
+
# 模型
|
| 25 |
+
self.model = KingAI(num_actions=NUM_ACTIONS).to(self.device)
|
| 26 |
+
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
| 27 |
+
self.model.eval()
|
| 28 |
+
|
| 29 |
+
# 控制器
|
| 30 |
+
self.ctrl = GameController()
|
| 31 |
+
self.mapper = ActionMapper(self.ctrl)
|
| 32 |
+
self.processor = ImageProcessor()
|
| 33 |
+
|
| 34 |
+
# 动作持续
|
| 35 |
+
self.current_move = None
|
| 36 |
+
self.move_remaining = 0
|
| 37 |
+
self.MOVE_DURATION_FRAMES = 10 # 移动持续10帧
|
| 38 |
+
|
| 39 |
+
self.frame_rate = 10 # 每秒10帧
|
| 40 |
+
self.inference_interval = 5 # 每5帧推理一次
|
| 41 |
+
|
| 42 |
+
self.frame_count = 0
|
| 43 |
+
self.inference_count = 0
|
| 44 |
+
|
| 45 |
+
print(f"✅ AI 加载完成,设备: {self.device}")
|
| 46 |
+
print(f"移动持续: {self.MOVE_DURATION_FRAMES} 帧 ({self.MOVE_DURATION_FRAMES/self.frame_rate:.1f}秒)")
|
| 47 |
+
|
| 48 |
+
def run(self):
|
| 49 |
+
print("\n🎮 AI 开始,按 Ctrl+C 停止\n")
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
while True:
|
| 53 |
+
# 每 N 帧推理一次
|
| 54 |
+
if self.frame_count % self.inference_interval == 0:
|
| 55 |
+
# 截图
|
| 56 |
+
screen = self.ctrl.screenshot()
|
| 57 |
+
screen_np = np.array(screen)
|
| 58 |
+
processed = self.processor.preprocess(screen_np)
|
| 59 |
+
tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
|
| 60 |
+
|
| 61 |
+
# 推理
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
logits = self.model(tensor)
|
| 64 |
+
action_id = torch.argmax(logits, dim=1).item()
|
| 65 |
+
|
| 66 |
+
action = ACTIONS[action_id]
|
| 67 |
+
self.inference_count += 1
|
| 68 |
+
|
| 69 |
+
# 执行动作
|
| 70 |
+
if action.startswith("move_"):
|
| 71 |
+
# 移动动作:设置持续帧数
|
| 72 |
+
self.current_move = action
|
| 73 |
+
self.move_remaining = self.MOVE_DURATION_FRAMES
|
| 74 |
+
print(f"[{self.inference_count}] {action} (持续{self.MOVE_DURATION_FRAMES}帧)")
|
| 75 |
+
else:
|
| 76 |
+
# 攻击/技能:立即执行
|
| 77 |
+
self.mapper.execute(action)
|
| 78 |
+
print(f"[{self.inference_count}] {action}")
|
| 79 |
+
|
| 80 |
+
# 每帧都执行当前移动(保持移动)
|
| 81 |
+
if self.current_move and self.move_remaining > 0:
|
| 82 |
+
self.mapper.execute(self.current_move)
|
| 83 |
+
self.move_remaining -= 1
|
| 84 |
+
|
| 85 |
+
if self.move_remaining == 0:
|
| 86 |
+
# 移动结束,停止
|
| 87 |
+
self.mapper.execute("move_stop")
|
| 88 |
+
self.current_move = None
|
| 89 |
+
print(" 移动停止")
|
| 90 |
+
|
| 91 |
+
self.frame_count += 1
|
| 92 |
+
time.sleep(1.0 / self.frame_rate)
|
| 93 |
+
|
| 94 |
+
except KeyboardInterrupt:
|
| 95 |
+
self.mapper.execute("move_stop")
|
| 96 |
+
print(f"\n✅ 停止,共推理 {self.inference_count} 次")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
|
| 100 |
+
model_path = "models/best_model.pth"
|
| 101 |
+
if not os.path.exists(model_path):
|
| 102 |
+
print(f"❌ 模型不存在: {model_path}")
|
| 103 |
+
sys.exit(1)
|
| 104 |
+
|
| 105 |
+
ai = AIPlayer(model_path)
|
| 106 |
+
ai.run()
|
models/best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:800bf47a5776df6cdcdd45d9a555ba8e0ac7a416467dd0a257871accfe1c0b3a
|
| 3 |
+
size 6765301
|
models/final_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:589f9b776ca64fb3e2725786b520bd26cf7b207d02d05d5148e8fd2ea2565d49
|
| 3 |
+
size 6765317
|
models/king_ai.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/king_ai.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
from config import NUM_ACTIONS, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
|
| 9 |
+
|
| 10 |
+
class KingAI(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
王者荣耀 AI 模型
|
| 13 |
+
输入: (batch, 3, 84, 84) 游戏画面
|
| 14 |
+
输出: (batch, NUM_ACTIONS) 动作概率
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, num_actions=NUM_ACTIONS):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
# 卷积层
|
| 21 |
+
self.conv1 = nn.Conv2d(IMG_CHANNELS, 32, kernel_size=8, stride=4)
|
| 22 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
|
| 23 |
+
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
|
| 24 |
+
|
| 25 |
+
# 计算全连接层输入维度
|
| 26 |
+
self._calculate_fc_dim()
|
| 27 |
+
|
| 28 |
+
# 全连接层
|
| 29 |
+
self.fc1 = nn.Linear(self.fc_input_dim, 512)
|
| 30 |
+
self.fc2 = nn.Linear(512, num_actions)
|
| 31 |
+
|
| 32 |
+
self._initialize_weights()
|
| 33 |
+
|
| 34 |
+
def _calculate_fc_dim(self):
|
| 35 |
+
"""计算卷积层输出维度"""
|
| 36 |
+
with torch.no_grad():
|
| 37 |
+
dummy = torch.zeros(1, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
|
| 38 |
+
x = F.relu(self.conv1(dummy))
|
| 39 |
+
x = F.relu(self.conv2(x))
|
| 40 |
+
x = F.relu(self.conv3(x))
|
| 41 |
+
self.fc_input_dim = x.view(1, -1).shape[1]
|
| 42 |
+
|
| 43 |
+
def _initialize_weights(self):
|
| 44 |
+
for m in self.modules():
|
| 45 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 46 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 47 |
+
if m.bias is not None:
|
| 48 |
+
nn.init.constant_(m.bias, 0)
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
x = F.relu(self.conv1(x))
|
| 52 |
+
x = F.relu(self.conv2(x))
|
| 53 |
+
x = F.relu(self.conv3(x))
|
| 54 |
+
x = x.view(x.size(0), -1)
|
| 55 |
+
x = F.relu(self.fc1(x))
|
| 56 |
+
x = self.fc2(x)
|
| 57 |
+
return x
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ActorCritic(nn.Module):
|
| 61 |
+
"""
|
| 62 |
+
Actor-Critic 网络,用于强化学习
|
| 63 |
+
共享特征层,分别输出动作概率和状态价值
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, num_actions=NUM_ACTIONS):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.conv1 = nn.Conv2d(IMG_CHANNELS, 32, 8, stride=4)
|
| 70 |
+
self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
|
| 71 |
+
self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
|
| 72 |
+
|
| 73 |
+
# 计算维度
|
| 74 |
+
with torch.no_grad():
|
| 75 |
+
dummy = torch.zeros(1, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
|
| 76 |
+
x = F.relu(self.conv1(dummy))
|
| 77 |
+
x = F.relu(self.conv2(x))
|
| 78 |
+
x = F.relu(self.conv3(x))
|
| 79 |
+
fc_dim = x.view(1, -1).shape[1]
|
| 80 |
+
|
| 81 |
+
self.fc_shared = nn.Linear(fc_dim, 512)
|
| 82 |
+
self.actor = nn.Linear(512, num_actions)
|
| 83 |
+
self.critic = nn.Linear(512, 1)
|
| 84 |
+
|
| 85 |
+
self._initialize_weights()
|
| 86 |
+
|
| 87 |
+
def _initialize_weights(self):
|
| 88 |
+
for m in self.modules():
|
| 89 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 90 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 91 |
+
|
| 92 |
+
def forward(self, x):
|
| 93 |
+
x = F.relu(self.conv1(x))
|
| 94 |
+
x = F.relu(self.conv2(x))
|
| 95 |
+
x = F.relu(self.conv3(x))
|
| 96 |
+
x = x.view(x.size(0), -1)
|
| 97 |
+
x = F.relu(self.fc_shared(x))
|
| 98 |
+
|
| 99 |
+
action_logits = self.actor(x)
|
| 100 |
+
value = self.critic(x)
|
| 101 |
+
|
| 102 |
+
return action_logits, value
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def test_model():
|
| 106 |
+
"""测试模型输出"""
|
| 107 |
+
model = KingAI()
|
| 108 |
+
dummy = torch.randn(4, 3, 84, 84)
|
| 109 |
+
output = model(dummy)
|
| 110 |
+
print(f"KingAI - 输入: {dummy.shape}, 输出: {output.shape}")
|
| 111 |
+
print(f"参数量: {sum(p.numel() for p in model.parameters()):,}")
|
| 112 |
+
|
| 113 |
+
ac_model = ActorCritic()
|
| 114 |
+
logits, values = ac_model(dummy)
|
| 115 |
+
print(f"ActorCritic - logits: {logits.shape}, values: {values.shape}")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
test_model()
|
models/train_bc.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/train_bc.py
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
+
|
| 10 |
+
from config import NUM_EPOCHS, LEARNING_RATE, MODEL_DIR, DEVICE
|
| 11 |
+
from models.king_ai import KingAI
|
| 12 |
+
from data.dataset import get_dataloaders
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train():
|
| 16 |
+
"""训练行为克隆模型"""
|
| 17 |
+
# 检测设备
|
| 18 |
+
if torch.backends.mps.is_available():
|
| 19 |
+
device = torch.device("mps")
|
| 20 |
+
print("✅ 使用 MPS (Apple Silicon GPU) 加速")
|
| 21 |
+
elif torch.cuda.is_available():
|
| 22 |
+
device = torch.device("cuda")
|
| 23 |
+
print("✅ 使用 CUDA (NVIDIA GPU) 加速")
|
| 24 |
+
else:
|
| 25 |
+
device = torch.device("cpu")
|
| 26 |
+
print("⚠️ 使用 CPU 训练")
|
| 27 |
+
|
| 28 |
+
# 加载数据
|
| 29 |
+
print("\n加载数据...")
|
| 30 |
+
train_loader, val_loader = get_dataloaders(
|
| 31 |
+
frames_dir="data/frames/game_01",
|
| 32 |
+
annotation_file="data/annotations/annotations.json"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# 创建模型
|
| 36 |
+
model = KingAI().to(device)
|
| 37 |
+
criterion = nn.CrossEntropyLoss()
|
| 38 |
+
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
| 39 |
+
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
|
| 40 |
+
|
| 41 |
+
print(f"\n开始训练 {NUM_EPOCHS} 轮...")
|
| 42 |
+
print("=" * 50)
|
| 43 |
+
|
| 44 |
+
best_acc = 0.0
|
| 45 |
+
|
| 46 |
+
for epoch in range(NUM_EPOCHS):
|
| 47 |
+
# 训练阶段
|
| 48 |
+
model.train()
|
| 49 |
+
train_loss = 0.0
|
| 50 |
+
train_correct = 0
|
| 51 |
+
train_total = 0
|
| 52 |
+
|
| 53 |
+
for images, actions in train_loader:
|
| 54 |
+
images, actions = images.to(device), actions.to(device)
|
| 55 |
+
|
| 56 |
+
optimizer.zero_grad()
|
| 57 |
+
outputs = model(images)
|
| 58 |
+
loss = criterion(outputs, actions)
|
| 59 |
+
loss.backward()
|
| 60 |
+
optimizer.step()
|
| 61 |
+
|
| 62 |
+
train_loss += loss.item()
|
| 63 |
+
_, predicted = torch.max(outputs, 1)
|
| 64 |
+
train_total += actions.size(0)
|
| 65 |
+
train_correct += (predicted == actions).sum().item()
|
| 66 |
+
|
| 67 |
+
train_acc = 100 * train_correct / train_total
|
| 68 |
+
|
| 69 |
+
# 验证阶段
|
| 70 |
+
model.eval()
|
| 71 |
+
val_loss = 0.0
|
| 72 |
+
val_correct = 0
|
| 73 |
+
val_total = 0
|
| 74 |
+
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
for images, actions in val_loader:
|
| 77 |
+
images, actions = images.to(device), actions.to(device)
|
| 78 |
+
outputs = model(images)
|
| 79 |
+
loss = criterion(outputs, actions)
|
| 80 |
+
|
| 81 |
+
val_loss += loss.item()
|
| 82 |
+
_, predicted = torch.max(outputs, 1)
|
| 83 |
+
val_total += actions.size(0)
|
| 84 |
+
val_correct += (predicted == actions).sum().item()
|
| 85 |
+
|
| 86 |
+
val_acc = 100 * val_correct / val_total
|
| 87 |
+
|
| 88 |
+
scheduler.step()
|
| 89 |
+
|
| 90 |
+
print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] "
|
| 91 |
+
f"Train Loss: {train_loss/len(train_loader):.4f} "
|
| 92 |
+
f"Train Acc: {train_acc:.2f}% | "
|
| 93 |
+
f"Val Loss: {val_loss/len(val_loader):.4f} "
|
| 94 |
+
f"Val Acc: {val_acc:.2f}%")
|
| 95 |
+
|
| 96 |
+
# 保存最佳模型
|
| 97 |
+
if val_acc > best_acc:
|
| 98 |
+
best_acc = val_acc
|
| 99 |
+
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "best_model.pth"))
|
| 100 |
+
print(f" ✅ 保存最佳模型 (准确率: {val_acc:.2f}%)")
|
| 101 |
+
|
| 102 |
+
# 保存最终模型
|
| 103 |
+
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "final_model.pth"))
|
| 104 |
+
print(f"\n🎉 训练完成!最佳验证准确率: {best_acc:.2f}%")
|
| 105 |
+
print(f"模型保存在: {MODEL_DIR}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
train()
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
opencv-python>=4.8.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
pillow>=10.0.0
|
| 6 |
+
matplotlib>=3.7.0
|
| 7 |
+
jupyter>=1.0.0
|
tests/test_mapper.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test_mapper.py
|
| 2 |
+
from inference.game_controller import GameController
|
| 3 |
+
from inference.action_mapper import ActionMapper
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def test_all_actions():
|
| 8 |
+
"""测试所有动作映射"""
|
| 9 |
+
ctrl = GameController()
|
| 10 |
+
mapper = ActionMapper(ctrl)
|
| 11 |
+
|
| 12 |
+
print("=" * 50)
|
| 13 |
+
print("王者荣耀AI控制测试 - 全面测试")
|
| 14 |
+
print("=" * 50)
|
| 15 |
+
|
| 16 |
+
# 1. 测试移动(8个方向)
|
| 17 |
+
print("\n[1/4] 测试移动控制...")
|
| 18 |
+
moves = [
|
| 19 |
+
"move_up", "move_down", "move_left", "move_right",
|
| 20 |
+
"move_upleft", "move_upright", "move_downleft", "move_downright"
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
for move in moves:
|
| 24 |
+
print(f" 执行: {move}")
|
| 25 |
+
mapper.execute(move)
|
| 26 |
+
time.sleep(0.5) # 每个动作间隔0.5秒
|
| 27 |
+
|
| 28 |
+
# 2. 测试战斗技能
|
| 29 |
+
print("\n[2/4] 测试战斗技能...")
|
| 30 |
+
combat_actions = ["attack", "skill_damage", "skill_control"]
|
| 31 |
+
|
| 32 |
+
for action in combat_actions:
|
| 33 |
+
print(f" 执行: {action}")
|
| 34 |
+
mapper.execute(action)
|
| 35 |
+
time.sleep(0.8)
|
| 36 |
+
|
| 37 |
+
# 3. 测试战术动作
|
| 38 |
+
print("\n[3/4] 测试战术动作...")
|
| 39 |
+
tactical_actions = ["recall", "heal", "summoner", "enhance", "upgrade"]
|
| 40 |
+
|
| 41 |
+
for action in tactical_actions:
|
| 42 |
+
print(f" 执行: {action}")
|
| 43 |
+
mapper.execute(action)
|
| 44 |
+
time.sleep(0.8)
|
| 45 |
+
|
| 46 |
+
# 4. 测试截图功能
|
| 47 |
+
print("\n[4/4] 测试截图功能...")
|
| 48 |
+
try:
|
| 49 |
+
img = ctrl.screenshot()
|
| 50 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 51 |
+
filename = f"screenshot_{timestamp}.png"
|
| 52 |
+
img.save(filename)
|
| 53 |
+
print(f" 截图已保存: {filename}")
|
| 54 |
+
|
| 55 |
+
# 获取屏幕尺寸
|
| 56 |
+
w, h = ctrl.get_screen_size()
|
| 57 |
+
print(f" 屏幕尺寸: {w}x{h}")
|
| 58 |
+
|
| 59 |
+
# 显示截图信息
|
| 60 |
+
print(f" 图片尺寸: {img.size}")
|
| 61 |
+
print(f" 图片模式: {img.mode}")
|
| 62 |
+
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f" 截图失败: {e}")
|
| 65 |
+
|
| 66 |
+
print("\n" + "=" * 50)
|
| 67 |
+
print("测试完成!")
|
| 68 |
+
print("=" * 50)
|
| 69 |
+
|
| 70 |
+
def test_single_action():
|
| 71 |
+
"""交互式测试单个动作"""
|
| 72 |
+
ctrl = GameController()
|
| 73 |
+
mapper = ActionMapper(ctrl)
|
| 74 |
+
|
| 75 |
+
print("\n=== 交互式测试模式 ===")
|
| 76 |
+
print("可用动作:")
|
| 77 |
+
print(" 移动: up, down, left, right, upleft, upright, downleft, downright")
|
| 78 |
+
print(" 战斗: attack, damage, control")
|
| 79 |
+
print(" 战术: recall, heal, summoner, enhance, upgrade")
|
| 80 |
+
print(" 其他: screenshot, quit")
|
| 81 |
+
print("-" * 40)
|
| 82 |
+
|
| 83 |
+
while True:
|
| 84 |
+
cmd = input("\n请输入动作: ").strip().lower()
|
| 85 |
+
|
| 86 |
+
if cmd == 'quit':
|
| 87 |
+
print("退出测试")
|
| 88 |
+
break
|
| 89 |
+
elif cmd == 'screenshot':
|
| 90 |
+
try:
|
| 91 |
+
img = ctrl.screenshot()
|
| 92 |
+
filename = "manual_screenshot.png"
|
| 93 |
+
img.save(filename)
|
| 94 |
+
print(f"截图已保存: {filename}")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"截图失败: {e}")
|
| 97 |
+
elif cmd in ['up', 'down', 'left', 'right', 'upleft', 'upright', 'downleft', 'downright']:
|
| 98 |
+
action = f"move_{cmd}"
|
| 99 |
+
print(f"执行: {action}")
|
| 100 |
+
mapper.execute(action)
|
| 101 |
+
elif cmd in ['attack', 'damage', 'control']:
|
| 102 |
+
if cmd == 'damage':
|
| 103 |
+
action = 'skill_damage'
|
| 104 |
+
elif cmd == 'control':
|
| 105 |
+
action = 'skill_control'
|
| 106 |
+
else:
|
| 107 |
+
action = cmd
|
| 108 |
+
print(f"执行: {action}")
|
| 109 |
+
mapper.execute(action)
|
| 110 |
+
elif cmd in ['recall', 'heal', 'summoner', 'enhance', 'upgrade']:
|
| 111 |
+
print(f"执行: {cmd}")
|
| 112 |
+
mapper.execute(cmd)
|
| 113 |
+
else:
|
| 114 |
+
print(f"未知动作: {cmd}")
|
| 115 |
+
|
| 116 |
+
time.sleep(0.3)
|
| 117 |
+
|
| 118 |
+
def test_with_delay():
|
| 119 |
+
"""带延迟的循环测试(用于观察)"""
|
| 120 |
+
ctrl = GameController()
|
| 121 |
+
mapper = ActionMapper(ctrl)
|
| 122 |
+
|
| 123 |
+
print("\n=== 自动循环测试模式 ===")
|
| 124 |
+
print("将依次执行所有动作,每个动作间隔1秒")
|
| 125 |
+
print("按 Ctrl+C 停止\n")
|
| 126 |
+
|
| 127 |
+
actions = [
|
| 128 |
+
# 移动
|
| 129 |
+
"move_up", "move_down", "move_left", "move_right",
|
| 130 |
+
"move_upleft", "move_upright", "move_downleft", "move_downright",
|
| 131 |
+
# 战斗
|
| 132 |
+
"attack", "skill_damage", "skill_control",
|
| 133 |
+
# 战术
|
| 134 |
+
"recall", "heal", "summoner", "enhance", "upgrade"
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
for i, action in enumerate(actions, 1):
|
| 139 |
+
print(f"[{i}/{len(actions)}] {action}")
|
| 140 |
+
mapper.execute(action)
|
| 141 |
+
time.sleep(1) # 每个动作间隔1秒
|
| 142 |
+
|
| 143 |
+
print("\n循环测试完成")
|
| 144 |
+
except KeyboardInterrupt:
|
| 145 |
+
print("\n用户中断测试")
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
print("选择测试模式:")
|
| 149 |
+
print("1. 全面测试(一次执行所有动作)")
|
| 150 |
+
print("2. 交互式测试(手动输入动作)")
|
| 151 |
+
print("3. 循环测试(自动循环,可观察)")
|
| 152 |
+
|
| 153 |
+
choice = input("\n请选择 (1/2/3): ").strip()
|
| 154 |
+
|
| 155 |
+
if choice == '1':
|
| 156 |
+
test_all_actions()
|
| 157 |
+
elif choice == '2':
|
| 158 |
+
test_single_action()
|
| 159 |
+
elif choice == '3':
|
| 160 |
+
test_with_delay()
|
| 161 |
+
else:
|
| 162 |
+
print("无效选择,运行全面测试")
|
| 163 |
+
test_all_actions()
|