"""FastAPI service that asks a PEFT-adapted causal LM for Battlegrounds actions.

The service builds a text prompt from a game state (either as compact JSON or
as a natural-language description), generates a completion with the adapted
model, and extracts a list of action dictionaries from that completion.
"""

import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

ROOT_DIR = Path(__file__).resolve().parent
RL_DIR = ROOT_DIR / "RL"

# The RL helpers are imported as a package below, so both directories must be
# on sys.path before that import runs.
for path in (ROOT_DIR, RL_DIR):
    path_str = str(path)
    if path_str not in sys.path:
        sys.path.append(path_str)

from RL.battleground_nl_utils import game_state_to_natural_language  # noqa: E402

BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"

DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2

app = FastAPI()

# Populated lazily by load_model(); both stay None until the first load.
tokenizer: Optional[AutoTokenizer] = None
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the prompt text below is kept as single-space-separated text,
# exactly as it appears in the source this file was recovered from. Confirm
# the whitespace (line breaks) against the prompt used when fine-tuning.
# NOTE(review): the example object shows empty placeholder values
# ('"type":""', '"tavern_index":') -- possibly placeholder text was lost at
# some point; confirm against the training prompt before changing it.
_FORMAT_AND_RULES = (
    'choose the best full-turn sequence of actions and respond with a single '
    'JSON object in this exact format: '
    '{"actions":[{"type":"","tavern_index":,"hand_index":,"board_index":,'
    '"card_name":}, ...]} '
    'Rules: '
    '1. Respond with JSON only. Do not add explanations or any extra text. '
    '2. The top-level object must have exactly one key: "actions". '
    '3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) '
    'of atomic action objects. '
    '4. Use 0-based integers for indices or null when not used. '
    '5. "type" must be one of: '
    '"BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD", '
    '"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN". '
    '6. "card_name" must exactly match a card name from the game state when '
    'required, otherwise null. '
)

# Prompt prefix used when the game state is sent as a JSON object.
INSTRUCTION_PREFIX = (
    'You are a Hearthstone Battlegrounds AI. '
    'Given the current game state as a JSON object, '
    + _FORMAT_AND_RULES
    + 'Now here is the game state JSON: '
)

# Prompt prefix used when the game state is rendered as natural language.
INSTRUCTION_PREFIX_NL = (
    'You are a Hearthstone Battlegrounds AI. '
    'Given the following natural language description of the current game '
    'state, '
    + _FORMAT_AND_RULES
    + 'Now here is the description of the game state: '
)


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate_actions``."""

    # Optional overrides; when omitted, build_prompt falls back to the values
    # found under state["game_state"] (JSON mode only).
    phase: Optional[str] = None
    turn: Optional[int] = None
    # Raw game state; forwarded to the prompt builder unchanged.
    state: Dict[str, Any]
    # "json" or "nl"; any value other than "nl" is treated as "json".
    input_mode: str = "json"
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    # Values <= 0 select greedy decoding instead of sampling.
    temperature: float = DEFAULT_TEMPERATURE


def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build the full model prompt for one request.

    Args:
        example: Mapping with a ``"state"`` entry and optional ``"phase"`` and
            ``"turn"`` entries.
        input_mode: ``"nl"`` renders the state through
            ``game_state_to_natural_language``; anything else serializes a
            compact JSON task object.

    Returns:
        The instruction prefix, a newline, and the rendered state.
    """
    state = example.get("state", {}) or {}

    if input_mode == "nl":
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)

    gs = state.get("game_state", {}) or {}

    # Callers may pass the keys with a None value (the endpoint always does
    # when the client omits them), so a plain dict.get default would never
    # apply; fall back explicitly on None.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")
    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)

    obj = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the list of action dicts from a model completion.

    The first JSON object found in ``text`` is decoded; any text before or
    after it is ignored. The object must carry an ``"actions"`` key (or, if
    that key is absent, an ``"action"`` key) whose value is a list of dicts or
    a single dict.

    Returns:
        The list of action dicts, or ``None`` when no usable object is found.
    """
    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None

    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # trailing commentary (even if it contains braces) does not break
        # parsing.
        obj, _ = json.JSONDecoder().raw_decode(text, start_idx)
    except (ValueError, RecursionError):
        return None

    if not isinstance(obj, dict):
        return None

    # "actions" takes precedence; "action" is only consulted when "actions"
    # is absent altogether.
    if "actions" in obj:
        payload = obj["actions"]
    elif "action" in obj:
        payload = obj["action"]
    else:
        return None

    if isinstance(payload, dict):
        payload = [payload]
    elif not isinstance(payload, list):
        return None

    # A single malformed step invalidates the whole sequence.
    if not all(isinstance(step, dict) for step in payload):
        return None
    return list(payload)


def load_model() -> None:
    """Load the tokenizer and adapter-wrapped model into module globals.

    Does nothing when both globals are already set.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    # NOTE(review): trust_remote_code=True allows code shipped with the model
    # repositories to execute; keep only if both repositories are trusted.
    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    on_gpu = torch.cuda.is_available()
    if on_gpu:
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    peft_model = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not on_gpu:
        # With device_map="auto" placement is already handled on GPU; only the
        # CPU path needs an explicit move.
        peft_model.to(device)
    peft_model.eval()

    tokenizer = tok
    model = peft_model


@app.on_event("startup")
async def _startup_event() -> None:
    """Load the model when the application starts."""
    load_model()


@app.get("/")
def root():
    """Return a static status payload naming the base and adapter models."""
    return {
        "status": "ok",
        "message": "DeepBattler Battlegrounds Space is running",
        "base_model": BASE_MODEL_ID,
        "adapter_model": ADAPTER_MODEL_ID,
    }


@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    """Generate a completion for the request's game state and parse its actions.

    Returns:
        A dict with ``"actions"`` (the parsed list, or ``None`` when the
        completion could not be parsed) and ``"raw_completion"`` (the decoded
        text generated after the prompt).
    """
    load_model()

    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example, input_mode=req.input_mode)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Sampling requires a strictly positive temperature; for temperature <= 0
    # use greedy decoding rather than letting generate() raise.
    use_sampling = req.temperature > 0
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": use_sampling,
    }
    if use_sampling:
        gen_kwargs["temperature"] = req.temperature

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = output_ids[0, prompt_len:]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)

    actions = parse_actions_from_completion(completion)

    return {
        "actions": actions,
        "raw_completion": completion,
    }