HonorNet_v1 / inference /run_ai.py
clarenceleo's picture
Duplicate from clarenceleo/HonorNet_v1
dd9e164
# inference/run_ai.py
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import time
import numpy as np
from inference.action_mapper import GameController, ActionMapper
from models.king_ai import KingAI
from data.preprocess import ImageProcessor
from config import NUM_ACTIONS, ACTIONS
class AIPlayer:
    """Play the game in real time with a trained ``KingAI`` policy.

    The control loop aims for ``frame_rate`` frames per second. Every
    ``inference_interval`` frames it screenshots the game, runs the model and
    picks the arg-max entry of ``ACTIONS``:

    * names starting with ``"move_"`` are re-sent every frame for
      ``MOVE_DURATION_FRAMES`` frames and then followed by ``"move_stop"``;
    * ``"move_stop"`` itself cancels any held movement immediately;
    * every other action (attack/skill) is executed once, right away, without
      interrupting a movement that is currently being held.
    """

    def __init__(self, model_path, move_duration_frames=10, frame_rate=10,
                 inference_interval=5):
        """Load the model weights and create the game I/O helpers.

        Args:
            model_path: Path to a checkpoint loadable with ``torch.load`` whose
                content is accepted by ``KingAI.load_state_dict``.
            move_duration_frames: Frames a chosen movement is held (default 10).
            frame_rate: Target loop rate in frames per second (default 10).
            inference_interval: Run the model once every this many frames
                (default 5).

        Raises:
            ValueError: If any of the timing parameters is not positive.
        """
        # Validate before touching the model/controller so bad arguments fail
        # fast (a zero here would otherwise surface later as ZeroDivisionError).
        for name, value in (("move_duration_frames", move_duration_frames),
                            ("frame_rate", frame_rate),
                            ("inference_interval", inference_interval)):
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")

        self.device = self._select_device()

        # Model. NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source; consider weights_only=True on
        # torch versions that support it.
        self.model = KingAI(num_actions=NUM_ACTIONS).to(self.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

        # Game I/O helpers: screen capture/input, action-name -> input mapping,
        # and screenshot preprocessing.
        self.ctrl = GameController()
        self.mapper = ActionMapper(self.ctrl)
        self.processor = ImageProcessor()

        # Movement-hold state and loop timing.
        self.current_move = None
        self.move_remaining = 0
        self.MOVE_DURATION_FRAMES = move_duration_frames
        self.frame_rate = frame_rate
        self.inference_interval = inference_interval
        self.frame_count = 0
        self.inference_count = 0

        print(f"✅ AI 加载完成,设备: {self.device}")
        print(f"移动持续: {self.MOVE_DURATION_FRAMES} 帧 ({self.MOVE_DURATION_FRAMES/self.frame_rate:.1f}秒)")

    @staticmethod
    def _select_device():
        """Return the preferred torch device: MPS, then CUDA, then CPU."""
        mps = getattr(torch.backends, "mps", None)  # absent on old torch builds
        if mps is not None and mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def run(self):
        """Run the control loop until interrupted with Ctrl+C.

        Movement is always released (``"move_stop"``) when the loop exits,
        including when an unexpected exception propagates out of it.
        """
        print("\n🎮 AI 开始,按 Ctrl+C 停止\n")
        period = 1.0 / self.frame_rate
        next_frame = time.monotonic()
        interrupted = False
        try:
            while True:
                if self.frame_count % self.inference_interval == 0:
                    self._dispatch(self._predict_action())
                self._advance_move()
                self.frame_count += 1

                # Sleep only for what is left of this frame's time slot, so
                # screenshot/inference time does not stretch every frame.
                next_frame += period
                delay = next_frame - time.monotonic()
                if delay > 0:
                    time.sleep(delay)
                else:
                    # Behind schedule: resync rather than running a burst of
                    # zero-delay frames to catch up.
                    next_frame = time.monotonic()
        except KeyboardInterrupt:
            interrupted = True
        finally:
            # Release movement on every exit path, not only on Ctrl+C.
            self._stop_move()
        if interrupted:
            print(f"\n✅ 停止,共推理 {self.inference_count} 次")

    def _predict_action(self):
        """Screenshot the game, run the model and return the chosen action name."""
        screen_np = np.array(self.ctrl.screenshot())
        processed = self.processor.preprocess(screen_np)
        # NOTE(review): assumes preprocess returns a single-sample numpy array
        # whose dtype/layout match KingAI's input — confirm in data.preprocess.
        tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
        action_id = torch.argmax(logits, dim=1).item()
        self.inference_count += 1
        return ACTIONS[action_id]

    def _dispatch(self, action):
        """Start holding a movement action, or execute any other action now."""
        if action == "move_stop":
            # An explicit stop cancels the hold instead of being "held" itself.
            self._stop_move()
            print(f"[{self.inference_count}] {action}")
        elif action.startswith("move_"):
            self.current_move = action
            self.move_remaining = self.MOVE_DURATION_FRAMES
            print(f"[{self.inference_count}] {action} (持续{self.MOVE_DURATION_FRAMES}帧)")
        else:
            # Attack / skill: fire once immediately; a held movement keeps going.
            self.mapper.execute(action)
            print(f"[{self.inference_count}] {action}")

    def _advance_move(self):
        """Re-send the held movement for this frame; stop it when its time is up."""
        if not self.current_move or self.move_remaining <= 0:
            return
        self.mapper.execute(self.current_move)
        self.move_remaining -= 1
        if self.move_remaining == 0:
            self.mapper.execute("move_stop")
            self.current_move = None
            print(" 移动停止")

    def _stop_move(self):
        """Send "move_stop" and clear any held movement state."""
        self.mapper.execute("move_stop")
        self.current_move = None
        self.move_remaining = 0
if __name__ == "__main__":
    # Resolve the default checkpoint against the project root (the same
    # directory appended to sys.path at the top of this file) so the script
    # works from any working directory. An explicit checkpoint path may be
    # passed as the first command-line argument.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if len(sys.argv) > 1:
        model_path = sys.argv[1]
    else:
        model_path = os.path.join(project_root, "models", "best_model.pth")
    if not os.path.isfile(model_path):
        print(f"❌ 模型不存在: {model_path}", file=sys.stderr)
        sys.exit(1)
    ai = AIPlayer(model_path)
    ai.run()