clarenceleo committed
Commit dd9e164 · 0 Parent(s)

Duplicate from clarenceleo/HonorNet_v1
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Tianyi Li
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,121 @@
+ # HonorNet 🎮
+
+ > An Honor of Kings AI based on behavior cloning: built from scratch, trained on a Mac, and run inside an Android emulator.
+
+ [![MIT License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
+
+ ## ✨ Project Overview
+
+ **HonorNet** is a fully open-source Honor of Kings AI. Highlights:
+
+ - 🍎 **Trains on a Mac**: uses MPS acceleration on Apple Silicon
+ - 📱 **Controls an Android emulator**: screenshots and touch events via ADB
+ - 🧠 **Behavior cloning**: learns from recorded human gameplay
+ - 🎯 **16-action space**: 8 movement directions + skills + attack + tactical actions
+ - 🔓 **MIT license**: fully open source, use it however you like
+
+ This is not a "call an API" demo; it is a complete pipeline **from data collection through model training to deployment**.
+
+ ## 🎯 Project Status
+
+ | Stage | Status |
+ |------|------|
+ | Data collection | ✅ Done (886 annotated frames) |
+ | Behavior-cloning training | ✅ Done (54.5% validation accuracy) |
+ | Emulator deployment | ✅ Done |
+ | RL fine-tuning | 🚧 In progress |
+
+ ## 🏗️ Project Layout
+
+ ```
+ HonorNet/
+ ├── data/             # data-processing scripts
+ ├── models/           # model definitions + training
+ ├── inference/        # emulator control + AI runtime
+ ├── config.py         # configuration
+ └── requirements.txt  # dependencies
+ ```
+
+ ## 🚀 Quick Start
+
+ ### 1. Environment setup
+
+ ```bash
+ # Clone the repository
+ git lfs clone https://huggingface.co/clarenceleo/HonorNet_v1
+ cd HonorNet_v1
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Install ADB (macOS)
+ brew install android-platform-tools
+ ```
+
+ ### 2. Prepare data
+
+ Put Honor of Kings 1v1 screen recordings into `data/raw_videos/`, then:
+
+ ```bash
+ # Extract frames
+ python data/extract_frames.py
+
+ # Preprocess
+ python data/preprocess.py
+
+ # Annotate actions (optional; an annotation tool is included)
+ python data/annotate.py
+ ```
+
+ ### 3. Train the model
+
+ ```bash
+ python models/train_bc.py
+ ```
+
+ ### 4. Let the AI play
+
+ 1. Start an Android emulator and enter Honor of Kings 1v1 mode
+ 2. Run the AI:
+
+ ```bash
+ python inference/run_ai.py
+ ```
+
+ ## 📊 Training Results
+
+ Trained for 50 epochs on 886 annotated frames:
+
+ | Metric | Value |
+ |------|------|
+ | Best validation accuracy | **54.49%** |
+ | Random baseline | 6.25% |
+ | Training accuracy | 95.48% |
+
+ **A 54.5% accuracy means the AI has learned**:
+ - to pick a movement direction from the screen
+ - when to basic-attack and when to cast skills
+ - when to recall and when to level up skills
+
+ ## 🛠️ Tech Stack
+
+ - **PyTorch**: deep-learning framework
+ - **ADB**: Android Debug Bridge, used to drive the emulator
+ - **OpenCV**: image processing
+ - **Android Studio AVD**: emulator runtime
+
+ ## 📈 Roadmap
+
+ - [ ] Reinforcement-learning fine-tuning (PPO)
+ - [ ] Support more heroes
+ - [ ] 5v5 multi-agent
+ - [ ] Online learning (learn while playing)
+
+ ## 🤝 Contributors
+
+ Model development and training-data cleaning/annotation: Tianyi Li (1637321445@qq.com)
+ Training-data recording: Yiyuan Jiang
+
+ Issues and pull requests are welcome!
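The 6.25% random baseline in the results table above is simply a uniform guess over the 16-action space; a one-line sanity check:

```python
# Uniform random guessing over 16 actions gives an expected accuracy of 1/16
NUM_ACTIONS = 16
baseline_pct = 100.0 / NUM_ACTIONS
print(f"{baseline_pct:.2f}%")  # 6.25%
```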
config.py ADDED
@@ -0,0 +1,47 @@
+ # config.py
+ import os
+
+ # Paths
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ DATA_DIR = os.path.join(BASE_DIR, "data")
+ RAW_VIDEO_DIR = os.path.join(DATA_DIR, "raw_videos")
+ FRAMES_DIR = os.path.join(DATA_DIR, "frames")
+ ANNOTATIONS_DIR = os.path.join(DATA_DIR, "annotations")
+ PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
+ MODEL_DIR = os.path.join(BASE_DIR, "models")
+
+ # Create the directories
+ os.makedirs(RAW_VIDEO_DIR, exist_ok=True)
+ os.makedirs(FRAMES_DIR, exist_ok=True)
+ os.makedirs(ANNOTATIONS_DIR, exist_ok=True)
+ os.makedirs(PROCESSED_DIR, exist_ok=True)
+ os.makedirs(MODEL_DIR, exist_ok=True)
+
+ # Image-processing parameters
+ IMG_HEIGHT = 84
+ IMG_WIDTH = 84
+ IMG_CHANNELS = 3
+ CROP_TOP_RATIO = 0.08  # crop the top 8% (removes the status bar)
+ CROP_BOTTOM_RATIO = 0.05  # crop the bottom 5% (removes the button bar)
+
+ # Frame-extraction parameters
+ EXTRACT_FPS = 5  # extract 5 frames per second
+
+ # Action space (16 actions)
+ ACTIONS = [
+     'move_up', 'move_down', 'move_left', 'move_right',
+     'move_upleft', 'move_upright', 'move_downleft', 'move_downright',
+     'attack', 'skill_1', 'skill_2', 'skill_3',
+     'recall', 'heal', 'summoner', 'upgrade'
+ ]
+
+ NUM_ACTIONS = len(ACTIONS)
+
+ # Training parameters
+ BATCH_SIZE = 64
+ LEARNING_RATE = 0.0001
+ NUM_EPOCHS = 50
+ TRAIN_SPLIT = 0.8
+
+ # Device
+ DEVICE = None  # auto-detected at runtime
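Because `run_ai.py` resolves a prediction as `ACTIONS[argmax(logits)]`, the order of the `ACTIONS` list above is load-bearing: indices 0 to 7 are the moves, 8 is `attack`, and so on. A minimal sketch of the index/name correspondence (`ACTION_TO_ID` is an illustrative helper, not part of the repo):

```python
# Same 16-action list as config.py; order must match the model's output head
ACTIONS = [
    'move_up', 'move_down', 'move_left', 'move_right',
    'move_upleft', 'move_upright', 'move_downleft', 'move_downright',
    'attack', 'skill_1', 'skill_2', 'skill_3',
    'recall', 'heal', 'summoner', 'upgrade',
]
ACTION_TO_ID = {name: i for i, name in enumerate(ACTIONS)}

print(ACTION_TO_ID['attack'])  # 8
print(ACTIONS[12])             # recall
```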
inference/action_mapper.py ADDED
@@ -0,0 +1,169 @@
+ # inference/action_mapper.py
+ import subprocess
+ import time
+ import io
+ from PIL import Image
+
+ class GameController:
+     """ADB game controller"""
+
+     def __init__(self, device_id="emulator-5554"):
+         self.device_id = device_id
+
+     def tap(self, x, y):
+         """Tap a point on screen"""
+         cmd = f"adb -s {self.device_id} shell input tap {x} {y}"
+         subprocess.run(cmd, shell=True)
+
+     def swipe(self, x1, y1, x2, y2, duration=50):
+         """Swipe gesture"""
+         cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration}"
+         subprocess.run(cmd, shell=True)
+
+     def swipe_continuous(self, x1, y1, x2, y2, duration_ms=50):
+         """Repeated swipe (used to keep the hero moving)"""
+         cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration_ms}"
+         subprocess.run(cmd, shell=True)
+
+     def screenshot(self):
+         """Capture the screen"""
+         cmd = f"adb -s {self.device_id} exec-out screencap -p"
+         output = subprocess.check_output(cmd, shell=True)
+         return Image.open(io.BytesIO(output))
+
+     def get_screen_size(self):
+         """Get the screen resolution"""
+         cmd = f"adb -s {self.device_id} shell wm size"
+         output = subprocess.check_output(cmd, shell=True).decode()
+         size_str = output.split(":")[1].strip()
+         w, h = map(int, size_str.split("x"))
+         return w, h
+
+
+ class ActionMapper:
+     """
+     Action mapper: holds movement by re-issuing swipes every frame
+     """
+
+     def __init__(self, controller):
+         self.ctrl = controller
+
+         # Button coordinates
+         self.buttons = {
+             "joystick_center": (448, 861),
+             "attack": (1936, 925),
+             "skill_1": (1723, 750),
+             "skill_2": (1927, 635),
+             "skill_3": (1443, 969),
+             "recall": (1150, 979),
+             "heal": (1283, 979),
+             "summoner": (1443, 969),
+             "upgrade": (1513, 833),
+         }
+
+         # Movement targets (where the joystick should be dragged to)
+         self.move_targets = {
+             "up": (448, 741),
+             "down": (448, 981),
+             "left": (328, 861),
+             "right": (568, 861),
+             "upleft": (363, 776),
+             "upright": (533, 776),
+             "downleft": (363, 946),
+             "downright": (533, 946),
+         }
+
+         # Current state
+         self.current_direction = None
+
+         # Action -> button mapping
+         self.action_to_button = {
+             "attack": "attack",
+             "skill_1": "skill_1",
+             "skill_2": "skill_2",
+             "skill_3": "skill_3",
+             "summoner": "summoner",
+             "recall": "recall",
+             "heal": "heal",
+             "upgrade": "upgrade",
+         }
+
+     def execute(self, action_name):
+         """
+         Execute an action (called once per frame)
+         """
+         # Movement actions
+         if action_name.startswith("move_"):
+             direction = action_name.replace("move_", "")
+             if direction in self.move_targets:
+                 self._do_move(direction)
+             return
+
+         # Stop moving
+         if action_name == "move_stop":
+             self._stop_move()
+             return
+
+         # Tap actions
+         if action_name in self.action_to_button:
+             button = self.action_to_button[action_name]
+             if button in self.buttons:
+                 x, y = self.buttons[button]
+                 self.ctrl.tap(x, y)
+
+     def _do_move(self, direction):
+         """
+         Move: swipe to the target position on every frame;
+         that is what keeps the hero moving continuously
+         """
+         cx, cy = self.buttons["joystick_center"]
+         tx, ty = self.move_targets[direction]
+
+         # Re-issue the swipe each frame to hold the joystick position
+         self.ctrl.swipe(cx, cy, tx, ty, duration=30)
+         self.current_direction = direction
+
+     def _stop_move(self):
+         """Stop moving: return the joystick to center"""
+         cx, cy = self.buttons["joystick_center"]
+         self.ctrl.swipe(cx, cy, cx, cy, duration=30)
+         self.current_direction = None
+
+     def attack(self):
+         """Basic attack"""
+         x, y = self.buttons["attack"]
+         self.ctrl.tap(x, y)
+
+     def skill_1(self):
+         """Skill 1"""
+         x, y = self.buttons["skill_1"]
+         self.ctrl.tap(x, y)
+
+     def skill_2(self):
+         """Skill 2"""
+         x, y = self.buttons["skill_2"]
+         self.ctrl.tap(x, y)
+
+     def skill_3(self):
+         """Skill 3"""
+         x, y = self.buttons["skill_3"]
+         self.ctrl.tap(x, y)
+
+     def recall(self):
+         """Recall to base"""
+         x, y = self.buttons["recall"]
+         self.ctrl.tap(x, y)
+
+
+ if __name__ == "__main__":
+     ctrl = GameController()
+     mapper = ActionMapper(ctrl)
+
+     print("Testing movement...")
+     print("Moving up for 3 seconds")
+     for i in range(30):
+         mapper.execute("move_up")
+         time.sleep(0.1)
+
+     print("Stop")
+     mapper.execute("move_stop")
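`GameController` builds each adb invocation as an f-string and runs it with `shell=True`; passing an argv list to `subprocess.run` instead avoids the shell (and any quoting concerns) entirely. A sketch of that alternative (`build_adb_cmd` and `adb_run` are hypothetical helpers, not part of the repo):

```python
import subprocess

def build_adb_cmd(device_id, *args):
    # argv-list form: no shell involved, every argument passed through verbatim
    return ["adb", "-s", device_id, *map(str, args)]

def adb_run(device_id, *args):
    # subprocess.run with a list does not need shell=True
    return subprocess.run(build_adb_cmd(device_id, *args), check=True)

print(build_adb_cmd("emulator-5554", "shell", "input", "tap", 448, 861))
```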
inference/game_controller.py ADDED
@@ -0,0 +1,33 @@
+ # game_controller.py
+ import subprocess
+ from PIL import Image
+ import io
+
+ class GameController:
+     def __init__(self, device_id="emulator-5554"):
+         self.device_id = device_id
+
+     def screenshot(self):
+         """Capture the screen and return a PIL Image"""
+         cmd = f"adb -s {self.device_id} exec-out screencap -p"
+         output = subprocess.check_output(cmd, shell=True)
+         img = Image.open(io.BytesIO(output))
+         return img
+
+     def tap(self, x, y):
+         """Tap the given coordinates"""
+         cmd = f"adb -s {self.device_id} shell input tap {x} {y}"
+         subprocess.run(cmd, shell=True)
+
+     def swipe(self, x1, y1, x2, y2, duration=50):
+         """Swipe gesture"""
+         cmd = f"adb -s {self.device_id} shell input swipe {x1} {y1} {x2} {y2} {duration}"
+         subprocess.run(cmd, shell=True)
+
+     def get_screen_size(self):
+         """Get the screen resolution"""
+         cmd = f"adb -s {self.device_id} shell wm size"
+         output = subprocess.check_output(cmd, shell=True).decode()
+         size_str = output.split(":")[1].strip()
+         w, h = map(int, size_str.split("x"))
+         return w, h
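`get_screen_size` parses the `wm size` output by splitting on `:` and `x`; the parsing can be pulled out and exercised without a device. A sketch (`parse_wm_size` is an illustrative helper; note that on devices with an override resolution `wm size` prints a second `Override size:` line, which this naive split, like the method above, does not handle):

```python
def parse_wm_size(output: str):
    # `adb shell wm size` normally prints e.g. "Physical size: 2280x1080"
    size_str = output.split(":")[1].strip()
    w, h = map(int, size_str.split("x"))
    return w, h

print(parse_wm_size("Physical size: 2280x1080"))  # (2280, 1080)
```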
inference/run_ai.py ADDED
@@ -0,0 +1,106 @@
+ # inference/run_ai.py
+ import sys
+ import os
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ import torch
+ import time
+ import numpy as np
+ from inference.action_mapper import GameController, ActionMapper
+ from models.king_ai import KingAI
+ from data.preprocess import ImageProcessor
+ from config import NUM_ACTIONS, ACTIONS
+
+ class AIPlayer:
+     def __init__(self, model_path):
+         # Device
+         if torch.backends.mps.is_available():
+             self.device = torch.device("mps")
+         elif torch.cuda.is_available():
+             self.device = torch.device("cuda")
+         else:
+             self.device = torch.device("cpu")
+
+         # Model
+         self.model = KingAI(num_actions=NUM_ACTIONS).to(self.device)
+         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+         self.model.eval()
+
+         # Controller
+         self.ctrl = GameController()
+         self.mapper = ActionMapper(self.ctrl)
+         self.processor = ImageProcessor()
+
+         # Action persistence
+         self.current_move = None
+         self.move_remaining = 0
+         self.MOVE_DURATION_FRAMES = 10  # keep a move going for 10 frames
+
+         self.frame_rate = 10  # 10 frames per second
+         self.inference_interval = 5  # run inference every 5 frames
+
+         self.frame_count = 0
+         self.inference_count = 0
+
+         print(f"✅ AI loaded, device: {self.device}")
+         print(f"Move duration: {self.MOVE_DURATION_FRAMES} frames ({self.MOVE_DURATION_FRAMES/self.frame_rate:.1f}s)")
+
+     def run(self):
+         print("\n🎮 AI running, press Ctrl+C to stop\n")
+
+         try:
+             while True:
+                 # Run inference every N frames
+                 if self.frame_count % self.inference_interval == 0:
+                     # Screenshot
+                     screen = self.ctrl.screenshot()
+                     screen_np = np.array(screen)
+                     processed = self.processor.preprocess(screen_np)
+                     tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
+
+                     # Inference
+                     with torch.no_grad():
+                         logits = self.model(tensor)
+                         action_id = torch.argmax(logits, dim=1).item()
+
+                     action = ACTIONS[action_id]
+                     self.inference_count += 1
+
+                     # Execute the action
+                     if action.startswith("move_"):
+                         # Movement: arm the persistence counter
+                         self.current_move = action
+                         self.move_remaining = self.MOVE_DURATION_FRAMES
+                         print(f"[{self.inference_count}] {action} (for {self.MOVE_DURATION_FRAMES} frames)")
+                     else:
+                         # Attack/skill: execute immediately
+                         self.mapper.execute(action)
+                         print(f"[{self.inference_count}] {action}")
+
+                 # Re-issue the current move every frame (keeps the hero moving)
+                 if self.current_move and self.move_remaining > 0:
+                     self.mapper.execute(self.current_move)
+                     self.move_remaining -= 1
+
+                     if self.move_remaining == 0:
+                         # Move finished, stop
+                         self.mapper.execute("move_stop")
+                         self.current_move = None
+                         print("  move stopped")
+
+                 self.frame_count += 1
+                 time.sleep(1.0 / self.frame_rate)
+
+         except KeyboardInterrupt:
+             self.mapper.execute("move_stop")
+             print(f"\n✅ Stopped after {self.inference_count} inference steps")
+
+
+ if __name__ == "__main__":
+     model_path = "models/best_model.pth"
+     if not os.path.exists(model_path):
+         print(f"❌ Model not found: {model_path}")
+         sys.exit(1)
+
+     ai = AIPlayer(model_path)
+     ai.run()
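The move-persistence logic in `run()` can be isolated as a tiny state machine: a new `move_*` prediction arms a 10-frame countdown, every frame decrements it, and reaching zero is the point where `move_stop` fires. A simplified, device-free sketch (`step` is an illustrative function, not repo code):

```python
MOVE_DURATION_FRAMES = 10  # matches AIPlayer.MOVE_DURATION_FRAMES

def step(state, new_action=None):
    """One frame of the move-persistence logic (simplified from AIPlayer.run)."""
    move, remaining = state
    if new_action is not None and new_action.startswith("move_"):
        move, remaining = new_action, MOVE_DURATION_FRAMES  # (re)arm the countdown
    if move is not None and remaining > 0:
        remaining -= 1        # this frame re-issues the joystick swipe
        if remaining == 0:
            move = None       # here the real loop executes "move_stop"
    return move, remaining

state = step((None, 0), "move_up")
for _ in range(9):
    state = step(state)
print(state)  # (None, 0)
```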
models/best_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:800bf47a5776df6cdcdd45d9a555ba8e0ac7a416467dd0a257871accfe1c0b3a
+ size 6765301
models/final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:589f9b776ca64fb3e2725786b520bd26cf7b207d02d05d5148e8fd2ea2565d49
+ size 6765317
models/king_ai.py ADDED
@@ -0,0 +1,119 @@
+ # models/king_ai.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import sys
+ import os
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from config import NUM_ACTIONS, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS
+
+ class KingAI(nn.Module):
+     """
+     Honor of Kings AI model
+     Input:  (batch, 3, 84, 84) game frame
+     Output: (batch, NUM_ACTIONS) action logits
+     """
+
+     def __init__(self, num_actions=NUM_ACTIONS):
+         super().__init__()
+
+         # Convolutional layers
+         self.conv1 = nn.Conv2d(IMG_CHANNELS, 32, kernel_size=8, stride=4)
+         self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
+         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
+
+         # Work out the fully connected input dimension
+         self._calculate_fc_dim()
+
+         # Fully connected layers
+         self.fc1 = nn.Linear(self.fc_input_dim, 512)
+         self.fc2 = nn.Linear(512, num_actions)
+
+         self._initialize_weights()
+
+     def _calculate_fc_dim(self):
+         """Compute the flattened size of the conv output"""
+         with torch.no_grad():
+             dummy = torch.zeros(1, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
+             x = F.relu(self.conv1(dummy))
+             x = F.relu(self.conv2(x))
+             x = F.relu(self.conv3(x))
+             self.fc_input_dim = x.view(1, -1).shape[1]
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, (nn.Conv2d, nn.Linear)):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = F.relu(self.conv1(x))
+         x = F.relu(self.conv2(x))
+         x = F.relu(self.conv3(x))
+         x = x.view(x.size(0), -1)
+         x = F.relu(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class ActorCritic(nn.Module):
+     """
+     Actor-Critic network for reinforcement learning.
+     Shares the feature layers; outputs action logits and a state value.
+     """
+
+     def __init__(self, num_actions=NUM_ACTIONS):
+         super().__init__()
+
+         self.conv1 = nn.Conv2d(IMG_CHANNELS, 32, 8, stride=4)
+         self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
+         self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
+
+         # Work out the flattened dimension
+         with torch.no_grad():
+             dummy = torch.zeros(1, IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
+             x = F.relu(self.conv1(dummy))
+             x = F.relu(self.conv2(x))
+             x = F.relu(self.conv3(x))
+             fc_dim = x.view(1, -1).shape[1]
+
+         self.fc_shared = nn.Linear(fc_dim, 512)
+         self.actor = nn.Linear(512, num_actions)
+         self.critic = nn.Linear(512, 1)
+
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, (nn.Conv2d, nn.Linear)):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+     def forward(self, x):
+         x = F.relu(self.conv1(x))
+         x = F.relu(self.conv2(x))
+         x = F.relu(self.conv3(x))
+         x = x.view(x.size(0), -1)
+         x = F.relu(self.fc_shared(x))
+
+         action_logits = self.actor(x)
+         value = self.critic(x)
+
+         return action_logits, value
+
+
+ def test_model():
+     """Smoke-test the model outputs"""
+     model = KingAI()
+     dummy = torch.randn(4, 3, 84, 84)
+     output = model(dummy)
+     print(f"KingAI - input: {dummy.shape}, output: {output.shape}")
+     print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")
+
+     ac_model = ActorCritic()
+     logits, values = ac_model(dummy)
+     print(f"ActorCritic - logits: {logits.shape}, values: {values.shape}")
+
+
+ if __name__ == "__main__":
+     test_model()
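`_calculate_fc_dim` derives the flatten size from a dummy forward pass; for the 84×84 input it can also be checked by hand with the usual valid-convolution formula `(size - kernel) // stride + 1`:

```python
def conv_out(size, kernel, stride):
    # Output side length of a valid (no-padding) convolution
    return (size - kernel) // stride + 1

side = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # conv1, conv2, conv3
    side = conv_out(side, kernel, stride)

# 84 -> 20 -> 9 -> 7, so with 64 output channels fc_input_dim = 64 * 7 * 7
print(side, 64 * side * side)  # 7 3136
```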
models/train_bc.py ADDED
@@ -0,0 +1,109 @@
+ # models/train_bc.py
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import numpy as np
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ from config import NUM_EPOCHS, LEARNING_RATE, MODEL_DIR, DEVICE
+ from models.king_ai import KingAI
+ from data.dataset import get_dataloaders
+
+
+ def train():
+     """Train the behavior-cloning model"""
+     # Detect the device
+     if torch.backends.mps.is_available():
+         device = torch.device("mps")
+         print("✅ Using MPS (Apple Silicon GPU) acceleration")
+     elif torch.cuda.is_available():
+         device = torch.device("cuda")
+         print("✅ Using CUDA (NVIDIA GPU) acceleration")
+     else:
+         device = torch.device("cpu")
+         print("⚠️ Training on CPU")
+
+     # Load the data
+     print("\nLoading data...")
+     train_loader, val_loader = get_dataloaders(
+         frames_dir="data/frames/game_01",
+         annotation_file="data/annotations/annotations.json"
+     )
+
+     # Build the model
+     model = KingAI().to(device)
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
+
+     print(f"\nTraining for {NUM_EPOCHS} epochs...")
+     print("=" * 50)
+
+     best_acc = 0.0
+
+     for epoch in range(NUM_EPOCHS):
+         # Training phase
+         model.train()
+         train_loss = 0.0
+         train_correct = 0
+         train_total = 0
+
+         for images, actions in train_loader:
+             images, actions = images.to(device), actions.to(device)
+
+             optimizer.zero_grad()
+             outputs = model(images)
+             loss = criterion(outputs, actions)
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+             _, predicted = torch.max(outputs, 1)
+             train_total += actions.size(0)
+             train_correct += (predicted == actions).sum().item()
+
+         train_acc = 100 * train_correct / train_total
+
+         # Validation phase
+         model.eval()
+         val_loss = 0.0
+         val_correct = 0
+         val_total = 0
+
+         with torch.no_grad():
+             for images, actions in val_loader:
+                 images, actions = images.to(device), actions.to(device)
+                 outputs = model(images)
+                 loss = criterion(outputs, actions)
+
+                 val_loss += loss.item()
+                 _, predicted = torch.max(outputs, 1)
+                 val_total += actions.size(0)
+                 val_correct += (predicted == actions).sum().item()
+
+         val_acc = 100 * val_correct / val_total
+
+         scheduler.step()
+
+         print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] "
+               f"Train Loss: {train_loss/len(train_loader):.4f} "
+               f"Train Acc: {train_acc:.2f}% | "
+               f"Val Loss: {val_loss/len(val_loader):.4f} "
+               f"Val Acc: {val_acc:.2f}%")
+
+         # Save the best model
+         if val_acc > best_acc:
+             best_acc = val_acc
+             torch.save(model.state_dict(), os.path.join(MODEL_DIR, "best_model.pth"))
+             print(f"  ✅ Saved best model (accuracy: {val_acc:.2f}%)")
+
+     # Save the final model
+     torch.save(model.state_dict(), os.path.join(MODEL_DIR, "final_model.pth"))
+     print(f"\n🎉 Training done! Best validation accuracy: {best_acc:.2f}%")
+     print(f"Models saved in: {MODEL_DIR}")
+
+
+ if __name__ == "__main__":
+     train()
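With `StepLR(step_size=20, gamma=0.5)` over 50 epochs, the learning rate is 1e-4 for epochs 1 to 20, 5e-5 for 21 to 40, and 2.5e-5 for 41 to 50. A device-free sketch of that schedule (`lr_at_epoch` is an illustrative helper, not repo code):

```python
LEARNING_RATE = 1e-4  # from config.py

def lr_at_epoch(epoch, base=LEARNING_RATE, step_size=20, gamma=0.5):
    # Mirrors torch.optim.lr_scheduler.StepLR: decay by gamma every step_size epochs
    return base * gamma ** (epoch // step_size)

print(lr_at_epoch(0), lr_at_epoch(20), lr_at_epoch(49))  # 0.0001 5e-05 2.5e-05
```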
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ opencv-python>=4.8.0
+ numpy>=1.24.0
+ pillow>=10.0.0
+ matplotlib>=3.7.0
+ jupyter>=1.0.0
tests/test_mapper.py ADDED
@@ -0,0 +1,163 @@
+ # test_mapper.py
+ from inference.game_controller import GameController
+ from inference.action_mapper import ActionMapper
+ import time
+
+ def test_all_actions():
+     """Exercise every action mapping"""
+     ctrl = GameController()
+     mapper = ActionMapper(ctrl)
+
+     print("=" * 50)
+     print("Honor of Kings AI control test - full run")
+     print("=" * 50)
+
+     # 1. Movement (8 directions)
+     print("\n[1/4] Testing movement control...")
+     moves = [
+         "move_up", "move_down", "move_left", "move_right",
+         "move_upleft", "move_upright", "move_downleft", "move_downright"
+     ]
+
+     for move in moves:
+         print(f"  executing: {move}")
+         mapper.execute(move)
+         time.sleep(0.5)  # 0.5s between actions
+
+     # 2. Combat skills (names must match config.ACTIONS / ActionMapper)
+     print("\n[2/4] Testing combat skills...")
+     combat_actions = ["attack", "skill_1", "skill_2", "skill_3"]
+
+     for action in combat_actions:
+         print(f"  executing: {action}")
+         mapper.execute(action)
+         time.sleep(0.8)
+
+     # 3. Tactical actions
+     print("\n[3/4] Testing tactical actions...")
+     tactical_actions = ["recall", "heal", "summoner", "upgrade"]
+
+     for action in tactical_actions:
+         print(f"  executing: {action}")
+         mapper.execute(action)
+         time.sleep(0.8)
+
+     # 4. Screenshots
+     print("\n[4/4] Testing screenshots...")
+     try:
+         img = ctrl.screenshot()
+         timestamp = time.strftime("%Y%m%d_%H%M%S")
+         filename = f"screenshot_{timestamp}.png"
+         img.save(filename)
+         print(f"  screenshot saved: {filename}")
+
+         # Screen size
+         w, h = ctrl.get_screen_size()
+         print(f"  screen size: {w}x{h}")
+
+         # Screenshot details
+         print(f"  image size: {img.size}")
+         print(f"  image mode: {img.mode}")
+
+     except Exception as e:
+         print(f"  screenshot failed: {e}")
+
+     print("\n" + "=" * 50)
+     print("Test finished!")
+     print("=" * 50)
+
+ def test_single_action():
+     """Interactively test a single action"""
+     ctrl = GameController()
+     mapper = ActionMapper(ctrl)
+
+     print("\n=== Interactive test mode ===")
+     print("Available actions:")
+     print("  move:   up, down, left, right, upleft, upright, downleft, downright")
+     print("  combat: attack, skill_1, skill_2, skill_3")
+     print("  tactic: recall, heal, summoner, upgrade")
+     print("  other:  screenshot, quit")
+     print("-" * 40)
+
+     while True:
+         cmd = input("\nEnter an action: ").strip().lower()
+
+         if cmd == 'quit':
+             print("Exiting test")
+             break
+         elif cmd == 'screenshot':
+             try:
+                 img = ctrl.screenshot()
+                 filename = "manual_screenshot.png"
+                 img.save(filename)
+                 print(f"Screenshot saved: {filename}")
+             except Exception as e:
+                 print(f"Screenshot failed: {e}")
+         elif cmd in ['up', 'down', 'left', 'right', 'upleft', 'upright', 'downleft', 'downright']:
+             action = f"move_{cmd}"
+             print(f"Executing: {action}")
+             mapper.execute(action)
+         elif cmd in ['attack', 'skill_1', 'skill_2', 'skill_3']:
+             print(f"Executing: {cmd}")
+             mapper.execute(cmd)
+         elif cmd in ['recall', 'heal', 'summoner', 'upgrade']:
+             print(f"Executing: {cmd}")
+             mapper.execute(cmd)
+         else:
+             print(f"Unknown action: {cmd}")
+
+         time.sleep(0.3)
+
+ def test_with_delay():
+     """Looping test with delays (for observation)"""
+     ctrl = GameController()
+     mapper = ActionMapper(ctrl)
+
+     print("\n=== Automatic loop test mode ===")
+     print("Executes every action in turn, 1 second apart")
+     print("Press Ctrl+C to stop\n")
+
+     actions = [
+         # movement
+         "move_up", "move_down", "move_left", "move_right",
+         "move_upleft", "move_upright", "move_downleft", "move_downright",
+         # combat
+         "attack", "skill_1", "skill_2", "skill_3",
+         # tactics
+         "recall", "heal", "summoner", "upgrade"
+     ]
+
+     try:
+         for i, action in enumerate(actions, 1):
+             print(f"[{i}/{len(actions)}] {action}")
+             mapper.execute(action)
+             time.sleep(1)  # 1s between actions
+
+         print("\nLoop test finished")
+     except KeyboardInterrupt:
+         print("\nTest interrupted by user")
+
+ if __name__ == "__main__":
+     print("Choose a test mode:")
+     print("1. Full test (run every action once)")
+     print("2. Interactive test (type actions manually)")
+     print("3. Loop test (automatic, for observation)")
+
+     choice = input("\nChoice (1/2/3): ").strip()
+
+     if choice == '1':
+         test_all_actions()
+     elif choice == '2':
+         test_single_action()
+     elif choice == '3':
+         test_with_delay()
+     else:
+         print("Invalid choice, running the full test")
+         test_all_actions()