Spaces:
Paused
Paused
File size: 6,696 Bytes
8036a6e e796c83 8036a6e e796c83 8036a6e e796c83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from RL.battleground_nl_utils import game_state_to_natural_language
# Hugging Face Hub id of the base causal LM that the PEFT adapter is applied to.
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
# Hub id of the PEFT adapter; the tokenizer is also loaded from this repo.
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
# Defaults for the generation-length and sampling-temperature request fields.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2
app = FastAPI()
# Populated lazily by load_model(); both stay None until the first load completes.
# NOTE(review): AutoTokenizer.from_pretrained returns a concrete tokenizer
# subclass, so the AutoTokenizer annotation is looser than it looks.
tokenizer: Optional[AutoTokenizer] = None
model: Optional[PeftModel] = None
# Device that request tensors (and, on CPU, the model) are moved to.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Instruction header used when input_mode is not "nl": the prompt is this text,
# a "\n", and the compact JSON encoding of {task, phase, turn, state}.
# NOTE(review): the adapter was presumably trained on this exact wording; treat
# the text as byte-sensitive and re-evaluate the model after any edit.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
# Instruction header used when input_mode == "nl": the prompt is this text,
# a "\n", and the natural-language rendering of the state produced by
# game_state_to_natural_language. The output-format rules mirror
# INSTRUCTION_PREFIX.
# NOTE(review): presumably byte-sensitive for the trained adapter -- confirm
# before rewording.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
class GenerateRequest(BaseModel):
    # Request body for POST /generate_actions. Comments (not a docstring) are
    # used so the generated OpenAPI schema description stays unchanged.

    # Optional phase/turn values, forwarded to build_prompt next to ``state``.
    phase: Optional[str] = None
    turn: Optional[int] = None
    # Raw game-state dict supplied by the client; its schema is not validated here.
    state: Dict[str, Any]
    input_mode: str = "json" # "json" or "nl"; build_prompt treats anything but "nl" as JSON
    # Generation controls.
    # NOTE(review): no range checks here, e.g. max_new_tokens <= 0 is passed
    # to generation unchanged.
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    temperature: float = DEFAULT_TEMPERATURE
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build the model prompt for a single request.

    Args:
        example: Dict with a ``"state"`` entry (the raw game state) and optional
            ``"phase"`` / ``"turn"`` entries. A missing *or None* phase/turn falls
            back to ``state["game_state"]["phase"]`` / ``["turn_number"]``, and
            then to ``"PlayerTurn"`` / ``0``.
        input_mode: ``"nl"`` renders the state with
            ``game_state_to_natural_language``; any other value uses the compact
            JSON encoding.

    Returns:
        The matching instruction prefix, a newline, and the serialized state.
    """
    state = example.get("state") or {}

    if input_mode == "nl":
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)

    gs = state.get("game_state") or {}

    # Callers such as the HTTP endpoint always pass these keys, often as None,
    # so fall back on None explicitly; dict.get's default only covers a missing key.
    # An explicit turn of 0 is a real value and is kept.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")
    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)

    obj = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    # Compact separators keep the prompt short; ensure_ascii=False keeps card
    # names in their original script.
    state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + state_text
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a raw model completion.

    The completion may surround the JSON with extra text (code fences, a stray
    ``{...}`` before it, or a trailing note). Decoding is attempted at each ``{``
    in turn; the first position that decodes to a JSON object is used, and
    anything after that object is ignored.

    Both ``{"actions": ...}`` and the singular ``{"action": ...}`` are accepted
    (``"actions"`` wins when present), and a single action object is wrapped in
    a one-element list.

    Returns:
        A new list of action dicts, or ``None`` if no JSON object can be decoded,
        the object has no list/dict ``actions``/``action`` entry, or any step is
        not a dict.
    """
    decoder = json.JSONDecoder()
    obj: Optional[Dict[str, Any]] = None
    start = text.find("{")
    while start != -1:
        try:
            # A value that starts with "{" always decodes to a dict.
            obj, _ = decoder.raw_decode(text[start:])
            break
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; try the next "{".
            start = text.find("{", start + 1)
    if obj is None:
        return None

    # Only the first key present is consulted: an unusable "actions" value does
    # not fall through to "action".
    for key in ("actions", "action"):
        if key in obj:
            seq = obj[key]
            break
    else:
        return None

    if isinstance(seq, dict):
        seq = [seq]
    if not isinstance(seq, list):
        return None
    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)
def load_model() -> None:
    """Load the tokenizer and PEFT-adapted model into the module globals, once.

    Returns immediately when both ``tokenizer`` and ``model`` are already set.
    With CUDA available, the base model is loaded in bfloat16 with
    ``device_map="auto"``. Otherwise it is loaded in float32 and the adapted
    model is moved to ``device`` explicitly. The model is put in eval mode
    before being published.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    # The tokenizer is loaded from the adapter repo, not the base repo.
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        ADAPTER_MODEL_ID, trust_remote_code=True
    )
    # Reuse EOS as the pad token when none is defined; pad on the left.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    loaded_tokenizer.padding_side = "left"

    on_gpu = torch.cuda.is_available()
    load_kwargs = {"trust_remote_code": True}
    if on_gpu:
        load_kwargs.update(device_map="auto", torch_dtype=torch.bfloat16)
    else:
        load_kwargs["torch_dtype"] = torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **load_kwargs)

    adapted = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    if not on_gpu:
        # No device_map on this path, so place the model explicitly ("cpu" here).
        adapted.to(device)
    adapted.eval()

    # Publish only after everything loaded, so the early-return check above
    # never sees a half-initialised pair.
    tokenizer = loaded_tokenizer
    model = adapted
@app.on_event("startup")
async def _startup_event() -> None:
    """Load the tokenizer and model when the application starts.

    NOTE(review): ``app.on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers. load_model() is synchronous, so it blocks the
    event loop while weights load; that is acceptable at startup, but worth
    revisiting when migrating.
    """
    load_model()
@app.get("/")
def root():
    """Liveness endpoint: a static status payload plus the configured model ids."""
    return dict(
        status="ok",
        message="DeepBattler Battlegrounds Space is running",
        base_model=BASE_MODEL_ID,
        adapter_model=ADAPTER_MODEL_ID,
    )
@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    """Generate a full-turn action sequence for the supplied game state.

    Builds a prompt from the request (JSON or natural-language mode), runs the
    adapted model, decodes only the newly generated tokens, and parses them.

    A ``temperature`` of 0 or below selects greedy decoding. transformers rejects
    non-positive temperatures when sampling, so such requests previously failed.

    Returns:
        ``{"actions": <list of action dicts or None if unparseable>,
        "raw_completion": <decoded model output>}``.
    """
    load_model()  # returns immediately once tokenizer and model are loaded

    example = {"phase": req.phase, "turn": req.turn, "state": req.state}
    prompt = build_prompt(example, input_mode=req.input_mode)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs: Dict[str, Any] = {"max_new_tokens": req.max_new_tokens}
    if req.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=req.temperature)
    else:
        # Sampling with temperature <= 0 raises ValueError inside transformers;
        # treat it as a request for deterministic (greedy) decoding instead.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    actions = parse_actions_from_completion(completion)

    return {
        "actions": actions,
        "raw_completion": completion,
    }