deepbattler / app.py
lbtwyk
Remove natural language input mode and simplify prompt building
f46834b
raw
history blame
5.38 kB
import json
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub repo IDs. load_model() loads the base causal LM from
# BASE_MODEL_ID, then wraps it with the PEFT adapter from ADAPTER_MODEL_ID
# (the tokenizer is also loaded from the adapter repo).
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
# Defaults for GenerateRequest.max_new_tokens / GenerateRequest.temperature.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2
# ASGI application; routes are registered on it with decorators below.
app = FastAPI()
# Lazily populated by load_model(); both stay None until loading succeeds.
# NOTE(review): AutoTokenizer.from_pretrained returns a concrete tokenizer
# subclass rather than an AutoTokenizer instance, so this annotation is loose.
tokenizer: Optional[AutoTokenizer] = None
model: Optional[PeftModel] = None
# Device that request tensors are moved to in generate_actions(); on the
# CPU path load_model() also moves the model here explicitly.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed instruction text that build_prompt() places in front of every
# serialized game state. It defines the {"actions": [...]} output schema that
# parse_actions_from_completion() extracts from the model's reply.
# NOTE(review): presumably this wording must match the prompts the adapter was
# trained on; confirm before editing the text.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
class GenerateRequest(BaseModel):
    """JSON request body accepted by ``POST /generate_actions``.

    Only ``state`` is required. ``phase`` and ``turn`` are optional values that
    generate_actions() hands to build_prompt() together with ``state``;
    ``max_new_tokens`` and ``temperature`` control text generation.
    """

    # Optional game phase label; None when the client omits it.
    phase: Optional[str] = None
    # Optional turn number; None when the client omits it.
    turn: Optional[int] = None
    # Game state embedded verbatim in the prompt; build_prompt() also reads an
    # optional nested "game_state" object from it.
    state: Dict[str, Any]
    # Maximum number of tokens to generate, forwarded to model.generate().
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    # Sampling temperature used during generation.
    temperature: float = DEFAULT_TEMPERATURE
def _first_not_none(*candidates: Any) -> Any:
    """Return the first candidate that is not ``None``, or ``None`` if all are."""
    return next((value for value in candidates if value is not None), None)


def build_prompt(example: Dict[str, Any]) -> str:
    """Build the JSON-mode prompt (the only mode supported by this Space).

    Args:
        example: Mapping with an optional ``"state"`` object plus optional
            ``"phase"`` and ``"turn"`` overrides. An override that is missing
            *or* explicitly ``None`` falls back to
            ``state["game_state"]["phase"]`` / ``["turn_number"]``, and then to
            ``"PlayerTurn"`` / ``0``. Falsy but meaningful values such as
            ``turn=0`` are kept.

    Returns:
        ``INSTRUCTION_PREFIX`` + ``"\\n"`` + a compact JSON object with keys
        ``task``, ``phase``, ``turn`` and ``state`` (in that order).
    """
    state = example.get("state") or {}
    game_state = state.get("game_state") if isinstance(state, dict) else None
    if not isinstance(game_state, dict):
        # A null or non-object game_state must not crash prompt building.
        game_state = {}

    # generate_actions() always sends "phase"/"turn" keys, set to None when the
    # client omits them, so dict.get(key, default) alone would never fall back.
    phase = _first_not_none(example.get("phase"), game_state.get("phase"), "PlayerTurn")
    turn = _first_not_none(example.get("turn"), game_state.get("turn_number"), 0)

    payload = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    state_text = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text
def _actions_from_object(obj: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
    """Return the validated action list held by a decoded JSON object, or None."""
    # "action" is only consulted when "actions" is absent.
    if "actions" in obj:
        value = obj["actions"]
    elif "action" in obj:
        value = obj["action"]
    else:
        return None

    if isinstance(value, dict):
        seq = [value]
    elif isinstance(value, list):
        seq = value
    else:
        return None

    # Reject the whole sequence if any step is not a JSON object.
    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a raw model completion.

    Tries to decode a JSON value at each ``{`` in ``text``, left to right, and
    returns the actions of the first decoded object that carries a valid
    ``"actions"`` (or, failing that key, ``"action"``) entry. Leading chatter,
    trailing notes, or a repeated second object therefore do not break parsing.

    Accepted shapes: ``{"actions": [{...}, ...]}`` or ``{"actions": {...}}``,
    and the same under the singular key ``"action"``.

    Returns:
        A new list of action dicts, or ``None`` when no usable object is found.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text[start:])
        except (ValueError, RecursionError):
            # Not a complete JSON value at this brace (e.g. truncated output).
            obj = None
        if isinstance(obj, dict):
            actions = _actions_from_object(obj)
            if actions is not None:
                return actions
        start = text.find("{", start + 1)
    return None
def load_model() -> None:
    """Load the tokenizer and adapter-wrapped model into the module globals once.

    Returns immediately when both globals are already set. The tokenizer is
    taken from the adapter repo. With CUDA the base model is loaded in
    bfloat16 using ``device_map="auto"``; otherwise it is loaded in float32 and
    moved to ``device`` explicitly. Globals are assigned only after every
    loading step has succeeded.
    """
    global tokenizer, model
    if not (tokenizer is None or model is None):
        return

    loaded_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    loaded_tokenizer.padding_side = "left"

    on_gpu = torch.cuda.is_available()
    base_kwargs: Dict[str, Any] = {"trust_remote_code": True}
    if on_gpu:
        base_kwargs.update(device_map="auto", torch_dtype=torch.bfloat16)
    else:
        base_kwargs.update(torch_dtype=torch.float32)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)

    adapted = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    if not on_gpu:
        # The CPU path passes no device_map, so place the model explicitly.
        adapted.to(device)
    adapted.eval()

    tokenizer = loaded_tokenizer
    model = adapted
@app.on_event("startup")
async def _startup_event() -> None:
    """Load the tokenizer and model when the server starts.

    load_model() is synchronous, so it runs on the event-loop thread and
    startup waits until loading finishes. generate_actions() also calls
    load_model(), which returns early once both globals are set.
    NOTE(review): recent FastAPI versions deprecate ``on_event`` in favour of
    lifespan handlers; consider migrating.
    """
    load_model()
@app.get("/")
def root():
    """Health check reporting that the Space is up and which models it serves."""
    payload = dict(
        status="ok",
        message="DeepBattler Battlegrounds Space is running",
    )
    payload["base_model"] = BASE_MODEL_ID
    payload["adapter_model"] = ADAPTER_MODEL_ID
    return payload
@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    """Generate a full-turn action sequence for the game state in ``req``.

    Builds the prompt with build_prompt(), runs the adapted model, and parses
    the completion with parse_actions_from_completion().

    A positive ``temperature`` enables sampling at that temperature. A value of
    0 or below selects greedy decoding instead, because transformers rejects
    sampling with a non-positive temperature.

    Returns:
        ``{"actions": <list of action dicts or None>, "raw_completion": <str>}``;
        ``actions`` is ``None`` when no action list could be parsed.

    Raises:
        HTTPException: 422 when ``max_new_tokens`` is smaller than 1.
    """
    if req.max_new_tokens < 1:
        raise HTTPException(status_code=422, detail="max_new_tokens must be at least 1")

    load_model()
    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_kwargs: Dict[str, Any] = {"max_new_tokens": req.max_new_tokens}
    if req.temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=req.temperature)
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # Drop the echoed prompt tokens; decode only the newly generated ones.
    generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    actions = parse_actions_from_completion(completion)
    return {
        "actions": actions,
        "raw_completion": completion,
    }