| |
| import sys |
| import os |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import torch |
| import time |
| import numpy as np |
| from inference.action_mapper import GameController, ActionMapper |
| from models.king_ai import KingAI |
| from data.preprocess import ImageProcessor |
| from config import NUM_ACTIONS, ACTIONS |
|
|
class AIPlayer:
    """Screen-capture control loop that plays the game with a trained KingAI model.

    Every ``inference_interval`` frames the current screen is captured,
    preprocessed and classified into one entry of ``ACTIONS``. Actions whose
    name starts with ``"move_"`` are held (re-sent every frame) for
    ``MOVE_DURATION_FRAMES`` frames and then released with ``"move_stop"``;
    every other action is executed once, immediately.
    """

    def __init__(self, model_path):
        """Load the model weights and set up screen capture and input control.

        Args:
            model_path: Path to a saved ``state_dict`` for ``KingAI``.
        """
        # Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model = KingAI(num_actions=NUM_ACTIONS).to(self.device)
        # SECURITY NOTE(review): torch.load unpickles the checkpoint unless
        # weights_only=True is in effect for the installed torch version, so
        # only load model files from trusted sources.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()  # switch layers such as dropout/batch-norm to inference mode

        self.ctrl = GameController()
        self.mapper = ActionMapper(self.ctrl)
        self.processor = ImageProcessor()

        # Movement state: the "move_*" action currently held, and how many
        # more frames it should keep being sent.
        self.current_move = None
        self.move_remaining = 0
        self.MOVE_DURATION_FRAMES = 10

        self.frame_rate = 10          # target loop frequency (frames per second)
        self.inference_interval = 5   # run the model once every N frames

        self.frame_count = 0
        self.inference_count = 0

        print(f"✅ AI 加载完成,设备: {self.device}")
        print(f"移动持续: {self.MOVE_DURATION_FRAMES} 帧 ({self.MOVE_DURATION_FRAMES/self.frame_rate:.1f}秒)")

    def run(self):
        """Run the control loop until interrupted with Ctrl+C.

        Each iteration is one frame: optionally infer and dispatch a new action,
        advance any held movement, then sleep for whatever remains of the
        frame's time budget.

        The held movement is always released with ``"move_stop"`` when the loop
        ends. On Ctrl+C the loop returns normally. Any other exception is
        re-raised after the release, so a crash cannot leave a key held down.
        """
        print("\n🎮 AI 开始,按 Ctrl+C 停止\n")

        try:
            while True:
                frame_start = time.monotonic()

                if self.frame_count % self.inference_interval == 0:
                    self._handle_action(self._infer_action())

                self._advance_move()
                self.frame_count += 1

                # Sleep only for what is left of this frame's budget. A fixed
                # sleep would add the screenshot and inference latency to every
                # inference frame, so moves would last longer than
                # MOVE_DURATION_FRAMES / frame_rate seconds.
                elapsed = time.monotonic() - frame_start
                time.sleep(max(0.0, 1.0 / self.frame_rate - elapsed))
        except KeyboardInterrupt:
            pass  # Ctrl+C is the normal way to stop; cleanup happens below.
        finally:
            # Runs on every exit path, including unexpected errors.
            self._release_move()

        print(f"\n✅ 停止,共推理 {self.inference_count} 次")

    def _infer_action(self):
        """Capture the screen, run the model once, and return the chosen action name."""
        screen_np = np.array(self.ctrl.screenshot())
        processed = self.processor.preprocess(screen_np)
        # NOTE(review): assumes preprocess() returns an array whose dtype and
        # layout match what KingAI expects; only a batch dim is added here.
        tensor = torch.from_numpy(processed).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(tensor)
        # Batch size is 1, so argmax over dim 1 yields a single action index.
        action_id = torch.argmax(logits, dim=1).item()
        return ACTIONS[action_id]

    def _handle_action(self, action):
        """Count one inference and either start holding a move or fire the action once."""
        self.inference_count += 1

        if action.startswith("move_"):
            # Start, or restart, holding this move for a fixed number of frames.
            self.current_move = action
            self.move_remaining = self.MOVE_DURATION_FRAMES
            print(f"[{self.inference_count}] {action} (持续{self.MOVE_DURATION_FRAMES}帧)")
        else:
            self.mapper.execute(action)
            print(f"[{self.inference_count}] {action}")

    def _advance_move(self):
        """Re-send the held move for this frame and release it once its frames run out."""
        if self.current_move and self.move_remaining > 0:
            self.mapper.execute(self.current_move)
            self.move_remaining -= 1

            if self.move_remaining == 0:
                self._release_move()
                print(" 移动停止")

    def _release_move(self):
        """Send ``"move_stop"`` and clear the held-movement state."""
        self.mapper.execute("move_stop")
        self.current_move = None
        self.move_remaining = 0
|
|
|
|
if __name__ == "__main__":
    # Script entry point: refuse to start unless the trained weights exist.
    weights_path = "models/best_model.pth"

    if os.path.exists(weights_path):
        AIPlayer(weights_path).run()
    else:
        print(f"❌ 模型不存在: {weights_path}")
        sys.exit(1)