Spaces:
Sleeping
Sleeping
File size: 5,376 Bytes
8036a6e e796c83 8036a6e e796c83 f46834b e796c83 f46834b e796c83 8036a6e e796c83 f46834b e796c83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import json
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Model identifiers --------------------------------------------------------
# load_model() loads the base checkpoint and applies the PEFT adapter on top.
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"

# --- Per-request generation defaults (see GenerateRequest) --------------------
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2

app = FastAPI()

# --- Lazily-populated globals -------------------------------------------------
# Both remain None until load_model() assigns them at the end of a successful load.
tokenizer: Optional[AutoTokenizer] = None
model = None

# Device that request tensors are moved to before generation.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Instruction text that build_prompt() places in front of every serialized game state.
# It asks the model for a {"actions": [...]} JSON object, which is the shape that
# parse_actions_from_completion() extracts from the completion.
# The string ends with a newline (the closing quotes sit on their own line).
# NOTE(review): this wording presumably matches the adapter's training prompts;
# confirm before editing, since any change alters the model's input distribution.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate_actions``.

    Attributes:
        phase: Optional game phase label forwarded to the prompt builder.
        turn: Optional turn number forwarded to the prompt builder.
        state: Raw game-state payload; embedded in the prompt JSON.
        max_new_tokens: Upper bound on generated tokens; must be > 0.
        temperature: Sampling temperature; must be >= 0.
    """

    phase: Optional[str] = None
    turn: Optional[int] = None
    state: Dict[str, Any]
    # generate() cannot produce anything with a non-positive token budget;
    # reject it at validation time (422) instead of failing mid-request (500).
    max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, gt=0)
    # Negative temperatures are never meaningful, so reject them up front.
    temperature: float = Field(DEFAULT_TEMPERATURE, ge=0.0)
def build_prompt(
    example: Dict[str, Any],
    *,
    instruction_prefix: Optional[str] = None,
) -> str:
    """Build a JSON-mode prompt (the only mode supported by this Space).

    The prompt is the instruction prefix, a newline, then a compact JSON object
    with keys ``task``, ``phase``, ``turn`` and ``state`` (in that order).

    Args:
        example: Mapping with optional ``"phase"``, ``"turn"`` and ``"state"``
            keys. A ``phase``/``turn`` that is missing *or None* falls back to
            ``state["game_state"]["phase"]`` / ``["turn_number"]``, and then to
            ``"PlayerTurn"`` / ``0``.
        instruction_prefix: Text placed before the JSON payload. Defaults to
            the module-level ``INSTRUCTION_PREFIX``.

    Returns:
        The full prompt string to feed to the tokenizer.
    """
    state = example.get("state") or {}
    gs = state.get("game_state") if isinstance(state, dict) else None
    if not isinstance(gs, dict):
        # Missing, null, or malformed game_state: no fallback values available.
        gs = {}

    # The HTTP endpoint always passes these keys (possibly as None), so a plain
    # dict.get(key, default) would never reach the fallback; test for None.
    phase = example.get("phase")
    if phase is None:
        phase = gs.get("phase", "PlayerTurn")
    turn = example.get("turn")
    if turn is None:
        turn = gs.get("turn_number", 0)

    prefix = INSTRUCTION_PREFIX if instruction_prefix is None else instruction_prefix
    payload = {
        "task": "battlegrounds_policy_v1",
        "phase": phase,
        "turn": turn,
        "state": state,
    }
    # Compact separators / raw unicode keep the prompt short.
    # NOTE(review): presumably this exact layout matches the adapter's training data.
    state_text = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    return prefix + "\n" + state_text
def _decode_first_json_object(text: str) -> Optional[Any]:
    """Return the first JSON object that decodes cleanly from a '{' in *text*, else None."""
    decoder = json.JSONDecoder()
    idx = text.find("{")
    while idx != -1:
        try:
            # raw_decode parses one value and ignores whatever follows it, so
            # trailing prose (even prose containing braces) does not break parsing.
            value, _end = decoder.raw_decode(text, idx)
        except (ValueError, RecursionError):
            # Not a valid object at this brace; try the next candidate.
            idx = text.find("{", idx + 1)
        else:
            return value
    return None


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a raw model completion.

    The first JSON object found in *text* is inspected. Its ``"actions"`` key
    (or ``"action"`` when ``"actions"`` is absent) may hold a list of action
    dicts or a single action dict.

    Args:
        text: Decoded model output; may contain surrounding prose or code fences.

    Returns:
        A new list of action dicts, or None if no suitable JSON object is found,
        the action value has the wrong type, or any step is not a dict.
    """
    obj = _decode_first_json_object(text)
    if not isinstance(obj, dict):
        return None

    # "actions" takes precedence; "action" is only consulted when it is absent.
    key = "actions" if "actions" in obj else "action"
    raw = obj.get(key)
    if isinstance(raw, list):
        seq = raw
    elif isinstance(raw, dict):
        seq = [raw]
    else:
        return None

    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)
def load_model() -> None:
    """Load the tokenizer, base model and PEFT adapter into the module globals.

    Returns immediately when both ``tokenizer`` and ``model`` are already set.
    The globals are assigned only after every loading step has succeeded, so
    an exception part-way through leaves them unchanged.
    """
    global tokenizer, model
    already_loaded = tokenizer is not None and model is not None
    if already_loaded:
        return

    # The tokenizer is read from the adapter repo, not the base model repo.
    # NOTE(review): trust_remote_code=True lets code from the Hub repo run;
    # confirm it is actually required for this checkpoint.
    new_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if new_tokenizer.pad_token is None:
        # Reuse EOS so a padding token is always defined.
        new_tokenizer.pad_token = new_tokenizer.eos_token
    new_tokenizer.padding_side = "left"

    on_gpu = torch.cuda.is_available()
    base_kwargs: Dict[str, Any] = {"trust_remote_code": True}
    if on_gpu:
        # bf16 with automatic device placement on GPU.
        base_kwargs["device_map"] = "auto"
        base_kwargs["torch_dtype"] = torch.bfloat16
    else:
        base_kwargs["torch_dtype"] = torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)

    adapted = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    if not on_gpu:
        # device_map was not used on CPU, so place the model explicitly.
        adapted.to(device)
    adapted.eval()

    tokenizer = new_tokenizer
    model = adapted
@app.on_event("startup")
async def _startup_event() -> None:
    """Load the model when the app starts, rather than lazily on the first request."""
    # load_model() is synchronous: this handler blocks the event loop until
    # loading finishes.
    load_model()
@app.get("/")
def root():
    """Report that the service is up, with the configured base/adapter model IDs."""
    payload = dict(
        status="ok",
        message="DeepBattler Battlegrounds Space is running",
        base_model=BASE_MODEL_ID,
        adapter_model=ADAPTER_MODEL_ID,
    )
    return payload
@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    """Generate a full-turn action sequence for the posted game state.

    A ``temperature`` above 0 samples at that temperature. A ``temperature``
    of 0 (or below) uses greedy decoding instead.

    Returns:
        A dict with ``"actions"`` (the parsed list of action dicts, or None if
        the completion could not be parsed) and ``"raw_completion"`` (the
        decoded model output, excluding the prompt).
    """
    load_model()  # no-op once the globals have been populated

    prompt = build_prompt({"phase": req.phase, "turn": req.turn, "state": req.state})
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    gen_kwargs: Dict[str, Any] = {"max_new_tokens": req.max_new_tokens}
    if req.temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = req.temperature
    else:
        # transformers rejects temperature <= 0 when do_sample=True (the request
        # would fail with a 500); treat it as a request for greedy decoding.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return {
        "actions": parse_actions_from_completion(completion),
        "raw_completion": completion,
    }