# deepbattler / app.py
# lbtwyk
# Add HF Space GPU inference for Battlegrounds Qwen
# e796c83
# raw
# history blame
# 6.7 kB
import json
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

from RL.battleground_nl_utils import game_state_to_natural_language
# Repository IDs passed to from_pretrained: the base causal LM and the PEFT
# adapter that load_model() applies on top of it.
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"

# Generation settings used when a request does not supply its own values.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2

app = FastAPI()

# Process-wide handles filled in by load_model(); they stay None until then.
tokenizer: Optional[AutoTokenizer] = None
model = None

# Target device for the tokenized inputs handed to model.generate().
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Output-format specification and rules shared verbatim by both prompt
# variants. Kept in one place so the JSON and natural-language prompts cannot
# drift apart; the text (including line breaks) is part of the model's input.
_ACTION_FORMAT_AND_RULES = """{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
"""

# Prompt prefix used when the game state is serialized as compact JSON.
INSTRUCTION_PREFIX = (
    "You are a Hearthstone Battlegrounds AI.\n"
    "Given the current game state as a JSON object, choose the best full-turn sequence\n"
    "of actions and respond with a single JSON object in this exact format:\n"
    + _ACTION_FORMAT_AND_RULES
    + "Now here is the game state JSON:\n"
)

# Prompt prefix used when the game state is rendered as natural language.
INSTRUCTION_PREFIX_NL = (
    "You are a Hearthstone Battlegrounds AI.\n"
    "Given the following natural language description of the current game state, choose\n"
    "the best full-turn sequence of actions and respond with a single JSON object in\n"
    "this exact format:\n"
    + _ACTION_FORMAT_AND_RULES
    + "Now here is the description of the game state:\n"
)
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate_actions``.

    ``phase`` and ``turn`` are optional and are forwarded to the prompt
    builder together with ``state``. ``input_mode`` selects how the state is
    rendered: ``"nl"`` for natural language; any other value is rendered as
    JSON. ``max_new_tokens`` and ``temperature`` control generation.
    """

    phase: Optional[str] = None
    turn: Optional[int] = None
    # Raw game-state object; its schema is defined by the game client.
    state: Dict[str, Any]
    input_mode: str = "json"  # "json" or "nl"; other values are rendered as JSON
    # Validate generation knobs at the HTTP boundary so a non-positive token
    # budget or a negative temperature is rejected with a 422 instead of
    # surfacing as a 500 from inside model.generate().
    max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1)
    temperature: float = Field(DEFAULT_TEMPERATURE, ge=0.0)
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build the text prompt fed to the model for one game state.

    Args:
        example: Mapping with a ``"state"`` entry (the raw game-state dict)
            and optional ``"phase"`` / ``"turn"`` entries. A missing *or*
            ``None`` phase/turn falls back to
            ``state["game_state"]["phase"]`` (default ``"PlayerTurn"``) and
            ``state["game_state"]["turn_number"]`` (default ``0``).
        input_mode: ``"nl"`` renders the state with
            ``game_state_to_natural_language``; any other value renders a
            compact JSON object.

    Returns:
        The instruction prefix for the chosen mode, a newline, then the
        rendered state text.
    """
    state = example.get("state") or {}

    if input_mode == "nl":
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)

    gs = state.get("game_state") or {}
    if not isinstance(gs, dict):
        # Tolerate a malformed "game_state" instead of crashing on .get().
        gs = {}

    # generate_actions always passes "phase"/"turn" keys (None when the client
    # omitted them), so test for None rather than key presence; otherwise the
    # game-state fallback never applies and the prompt carries nulls.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")
    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)

    payload = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    # Compact separators keep the prompt short; ensure_ascii=False keeps card
    # names readable rather than \u-escaped.
    state_text = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a raw model completion.

    Scans ``text`` for the first ``{`` that starts a complete, valid JSON
    object, so surrounding prose, code fences, trailing braces, or a second
    object after the first do not break parsing. The object's ``"actions"``
    value (or ``"action"`` if ``"actions"`` is absent) may be a list of step
    dicts or a single step dict.

    Returns:
        A new list of action dicts (possibly empty), or ``None`` when no JSON
        object is found, the key is missing, its value is neither a list nor
        a dict, or any step is not a dict. Never raises for malformed text.
    """
    decoder = json.JSONDecoder()
    obj: Any = None
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            break
        except (ValueError, RecursionError):
            # Not a valid object from this brace (JSONDecodeError is a
            # ValueError); try the next candidate opening brace.
            obj = None
            start = text.find("{", start + 1)

    if not isinstance(obj, dict):
        return None

    # "actions" takes precedence; "action" is only consulted when "actions"
    # is absent altogether (an invalid "actions" value is not overridden).
    seq: Any = None
    for key in ("actions", "action"):
        if key in obj:
            seq = obj[key]
            break

    if isinstance(seq, dict):
        seq = [seq]
    if not isinstance(seq, list):
        return None
    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)
def load_model() -> None:
    """Load the tokenizer and PEFT-adapted model into the module globals.

    Does nothing if both ``tokenizer`` and ``model`` are already set. On a
    CUDA machine the base model is loaded in bfloat16 with
    ``device_map="auto"``. Otherwise it is loaded in float32 and the adapted
    model is moved to ``device``. The model is put in eval mode before being
    published.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        # Reuse EOS as the pad token so padding is always defined.
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    on_gpu = torch.cuda.is_available()
    base_kwargs: Dict[str, Any] = {"trust_remote_code": True}
    if on_gpu:
        base_kwargs["device_map"] = "auto"
        base_kwargs["torch_dtype"] = torch.bfloat16
    else:
        base_kwargs["torch_dtype"] = torch.float32
    base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)

    adapted = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not on_gpu:
        # device_map handles placement on GPU; on CPU move explicitly.
        adapted.to(device)
    adapted.eval()

    tokenizer, model = tok, adapted
@app.on_event("startup")
async def _startup_event() -> None:
    """Load the tokenizer and model once when the application starts.

    NOTE(review): ``load_model`` is a plain synchronous function, so this
    handler blocks the event loop until loading finishes. Newer FastAPI
    releases recommend lifespan handlers over ``on_event``; verify against
    the pinned FastAPI version and consider migrating.
    """
    load_model()
@app.get("/")
def root():
    """Health check: report service status and the configured model IDs."""
    return dict(
        status="ok",
        message="DeepBattler Battlegrounds Space is running",
        base_model=BASE_MODEL_ID,
        adapter_model=ADAPTER_MODEL_ID,
    )
@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    """Generate a full-turn action sequence for the posted game state.

    Builds the prompt from the request, runs ``model.generate`` on it, and
    decodes only the newly generated tokens.

    Returns:
        ``{"actions": ..., "raw_completion": ...}`` where ``actions`` is the
        parsed list of action dicts (``None`` if the completion could not be
        parsed) and ``raw_completion`` is the decoded model output.
    """
    load_model()

    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example, input_mode=req.input_mode)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": req.max_new_tokens,
        # Explicit pad id avoids generate() guessing one (and warning).
        "pad_token_id": tokenizer.pad_token_id,
    }
    if req.temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = req.temperature
    else:
        # transformers rejects temperature <= 0 when sampling; interpret it
        # as a request for deterministic greedy decoding instead of a 500.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # Drop the echoed prompt tokens; keep only the continuation.
    generated_ids = output_ids[0, prompt_len:]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return {
        "actions": parse_actions_from_completion(completion),
        "raw_completion": completion,
    }