deepbattler / RL /infer_battleground_cloud.py
lbtwyk
Add HF Space GPU inference for Battlegrounds Qwen
e796c83
raw
history blame
9.3 kB
#!/usr/bin/env python
# infer_battleground_cloud.py
#
# Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
#
# Usage examples:
# PYTHONPATH=. python RL/infer_battleground_cloud.py \
# --input RL/datasets/game_history_2_flat.json \
# --output RL/datasets/game_history_2_actions.jsonl \
# --model-id iteratehack/deepbattler-battleground-gamehistory
#
# or, if you deploy a dedicated Inference Endpoint:
# PYTHONPATH=. python RL/infer_battleground_cloud.py \
# --input RL/datasets/game_history_2_flat.json \
# --output RL/datasets/game_history_2_actions.jsonl \
# --endpoint https://<your-endpoint>.inference.huggingface.cloud
#
# The script expects the same "state" structure and action JSON schema as
# train_battleground_rlaif_gamehistory.py.
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from huggingface_hub import InferenceClient
from RL.battleground_nl_utils import game_state_to_natural_language
# Prompt header for input_mode="json": the raw state dict is serialized compactly
# and appended after this text (see build_prompt). The action schema and rules
# must stay in sync with train_battleground_rlaif_gamehistory.py so the
# inference prompt distribution matches training.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
# Prompt header for input_mode="nl": a natural-language rendering of the state
# (via game_state_to_natural_language) is appended after this text instead of
# raw JSON. Same action schema and rules as INSTRUCTION_PREFIX.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Render one flattened game_history row into a model prompt string.

    Mirrors _build_prompt in train_battleground_rlaif_gamehistory.py so the
    inference-time prompt distribution matches the training distribution.

    The example is expected to carry:
      - phase: string (e.g., "PlayerTurn")
      - turn: int
      - state: nested dict with keys game_state, player_hero, resources, board_state
    """
    state = example.get("state", {}) or {}
    if input_mode == "nl":
        # Natural-language mode: describe the state in prose.
        header = INSTRUCTION_PREFIX_NL
        body = game_state_to_natural_language(state)
    else:
        # JSON mode: wrap the raw state in the same compact envelope used
        # during training; phase/turn fall back to fields inside game_state.
        game_state = state.get("game_state", {}) or {}
        envelope = {
            "task": "battlegrounds_policy_v1",
            "phase": example.get("phase", game_state.get("phase", "PlayerTurn")),
            "turn": example.get("turn", game_state.get("turn_number", 0)),
            "state": state,
        }
        header = INSTRUCTION_PREFIX
        body = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False)
    return header + "\n" + body
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract a list of atomic action dicts from a model completion.

    Accepted shapes (same as the training reward parser):
      - {"actions": [ {...}, ... ]}  (canonical; a lone dict is wrapped)
      - {"action":  [ {...}, ... ]}  (tolerated fallback)

    Returns None on any parse failure: no braces, invalid JSON, non-dict
    top level, missing/ill-typed key, or a non-dict step in the sequence.
    """
    stripped = text.strip()
    lo = stripped.find("{")
    hi = stripped.rfind("}")
    # Require at least one "{" and one "}" so we can carve out a JSON span
    # even when the model wraps it in prose or markdown fences.
    if lo == -1 or hi == -1:
        return None
    try:
        payload = json.loads(stripped[lo : hi + 1])
    except Exception:
        return None
    if not isinstance(payload, dict):
        return None
    steps = None
    # "actions" wins over "action"; a present-but-ill-typed key is not
    # rescued by the fallback key (break, matching the original elif chain).
    for key in ("actions", "action"):
        if key in payload:
            value = payload[key]
            if isinstance(value, list):
                steps = value
            elif isinstance(value, dict):
                steps = [value]
            break
    if steps is None:
        return None
    if not all(isinstance(step, dict) for step in steps):
        return None
    return list(steps)
def run_inference(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Query the model once per example and return enriched copies.

    Each output row is a shallow copy of the input example plus:
      - raw_completion: raw text returned by the model
      - actions: parsed list of atomic action dicts (None on parse failure)
    """
    enriched: List[Dict[str, Any]] = []
    for example in examples:
        raw = client.text_generation(
            build_prompt(example, input_mode=input_mode),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        record = dict(example)
        record["raw_completion"] = raw
        record["actions"] = parse_actions_from_completion(raw)
        enriched.append(record)
    return enriched
def load_examples(path: str) -> List[Dict[str, Any]]:
    """Load a JSON file of flattened game_history rows.

    Raises FileNotFoundError if the path does not exist and ValueError if
    the top-level JSON value is not a list.
    """
    src = Path(path)
    if not src.exists():
        raise FileNotFoundError(path)
    data = json.loads(src.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Expected input JSON to be a list of examples (flat rows)")
    return data
def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write rows to a JSONL file, creating parent directories as needed."""
    dest = Path(path)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # One JSON object per line; ensure_ascii=False keeps card names readable.
    lines = (json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    with dest.open("w", encoding="utf-8") as handle:
        handle.writelines(lines)
def parse_args() -> argparse.Namespace:
    """Parse CLI flags; exits with an error unless --model-id or --endpoint is given."""
    cli = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )
    cli.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of flattened game_history rows).",
    )
    cli.add_argument(
        "--output",
        required=True,
        help="Path to output JSONL file with actions and raw completions.",
    )
    cli.add_argument(
        "--model-id",
        default=None,
        help=(
            "Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). "
            "If provided, serverless / hosted inference will be used."
        ),
    )
    cli.add_argument(
        "--endpoint",
        default=None,
        help=(
            "Full URL of a dedicated Inference Endpoint. If provided, this takes precedence "
            "over --model-id."
        ),
    )
    cli.add_argument(
        "--hf-token",
        default=None,
        help=(
            "Hugging Face access token. If omitted, the token from `huggingface-cli login` "
            "or HF_TOKEN env var will be used."
        ),
    )
    cli.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Match the input_mode used during training (json or nl).",
    )
    cli.add_argument("--max-new-tokens", type=int, default=256)
    cli.add_argument("--temperature", type=float, default=0.2)
    parsed = cli.parse_args()
    # At least one inference target is mandatory.
    if not (parsed.model_id or parsed.endpoint):
        cli.error("You must provide either --model-id or --endpoint")
    return parsed
def main() -> None:
    """CLI entry point: load rows, query the hosted model, write enriched JSONL."""
    args = parse_args()
    # A dedicated endpoint URL takes precedence over a hosted model repo id
    # (parse_args guarantees at least one of the two is set).
    target = args.endpoint if args.endpoint else args.model_id
    client = InferenceClient(target, token=args.hf_token)
    rows = run_inference(
        client,
        load_examples(args.input),
        input_mode=args.input_mode,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
    )
    save_results(args.output, rows)
    print(f"Wrote {len(rows)} rows to {args.output}")
if __name__ == "__main__":
    main()