from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional

import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from RL.battleground_nl_utils import game_state_to_natural_language

BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"

DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2

app = FastAPI()

# Globals populated lazily by load_model() and shared across requests.
tokenizer: Optional[AutoTokenizer] = None
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
class GenerateRequest(BaseModel):
    phase: Optional[str] = None
    turn: Optional[int] = None
    state: Dict[str, Any]
    input_mode: str = "json"  # "json" or "nl"
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    temperature: float = DEFAULT_TEMPERATURE
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    state = example.get("state", {}) or {}
    if input_mode == "nl":
        # Natural-language mode: describe the state in prose via the shared converter.
        nl_state = game_state_to_natural_language(state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        # JSON mode: wrap the raw state in a compact task envelope.
        gs = state.get("game_state", {}) or {}
        # Fall back to the values embedded in the game state when the request
        # leaves phase/turn unset (the keys are present but None).
        phase = example.get("phase")
        if phase is None:
            phase = gs.get("phase", "PlayerTurn")
        turn = example.get("turn")
        if turn is None:
            turn = gs.get("turn_number", 0)
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": phase,
            "turn": turn,
            "state": state,
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text
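# For reference, in "json" mode the text appended after the instruction prefix is a
# compact envelope like the following (placeholder values):
#
#   {"task":"battlegrounds_policy_v1","phase":"PlayerTurn","turn":3,"state":{...}}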
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a model completion, or return None if malformed."""
    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]
    try:
        obj = json.loads(json_str)
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None
    # Accept either "actions" (preferred) or "action", as a list or a single object.
    seq = None
    if "actions" in obj:
        if isinstance(obj["actions"], list):
            seq = obj["actions"]
        elif isinstance(obj["actions"], dict):
            seq = [obj["actions"]]
    elif "action" in obj:
        if isinstance(obj["action"], list):
            seq = obj["action"]
        elif isinstance(obj["action"], dict):
            seq = [obj["action"]]
    if seq is None:
        return None
    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions
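# A quick sketch of how the parser behaves (inputs below are made up):
#
#   parse_actions_from_completion('noise {"actions":[{"type":"ROLL"}]} noise')
#     -> [{"type": "ROLL"}]
#   parse_actions_from_completion('{"action":{"type":"END_TURN"}}')
#     -> [{"type": "END_TURN"}]
#   parse_actions_from_completion("no json here")
#     -> None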
def load_model() -> None:
    """Load the tokenizer, base model, and LoRA adapter once and cache them globally."""
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return
    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    if torch.cuda.is_available():
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
    peft_model = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not torch.cuda.is_available():
        peft_model.to(device)
    peft_model.eval()
    tokenizer = tok
    model = peft_model
@app.on_event("startup")
async def _startup_event() -> None:
    # Warm the model at process startup so the first request does not pay the load cost.
    load_model()
@app.get("/")
def root():
    # Health/status endpoint.
    return {
        "status": "ok",
        "message": "DeepBattler Battlegrounds Space is running",
        "base_model": BASE_MODEL_ID,
        "adapter_model": ADAPTER_MODEL_ID,
    }
@app.post("/generate")  # NOTE: route path assumed
def generate_actions(req: GenerateRequest):
    load_model()
    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example, input_mode=req.input_mode)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
        )
    # Strip the prompt tokens so only the newly generated completion is decoded.
    generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    actions = parse_actions_from_completion(completion)
    return {
        "actions": actions,
        "raw_completion": completion,
    }
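# --- Example client call (sketch) ---
# Assumes the POST route above is mounted at /generate and the Space listens on
# port 7860 (the usual Hugging Face Spaces port); adjust the URL, payload, and
# "state" contents to match your deployment.
#
#   import requests
#
#   payload = {
#       "phase": "PlayerTurn",
#       "turn": 3,
#       "state": {"game_state": {"phase": "PlayerTurn", "turn_number": 3}},
#       "input_mode": "json",
#   }
#   resp = requests.post("http://localhost:7860/generate", json=payload, timeout=120)
#   print(resp.json()["actions"])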