# NOTE: non-code Hugging Face Space page banner ("Spaces: Sleeping") removed during extraction.
#!/usr/bin/env python
# infer_battleground_cloud.py
#
# Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
#
# Backends supported:
#   1. Hugging Face Space exposing /generate_actions (preferred for this project)
#   2. Hugging Face Inference Endpoint / Hosted model via InferenceClient
#
# Usage examples:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --model-id iteratehack/deepbattler-battleground-gamehistory
#
# or, if you deploy a dedicated Inference Endpoint:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --endpoint https://<your-endpoint>.inference.huggingface.cloud
#
# The script expects the same "state" structure and action JSON schema as
# train_battleground_rlaif_gamehistory.py.
# Standard library.
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

# Third party.
import requests
from huggingface_hub import InferenceClient

# Project local.
from RL.battleground_nl_utils import game_state_to_natural_language
# System-style instruction prepended to the JSON game-state payload.
# Kept in lockstep with the prompt used by train_battleground_rlaif_gamehistory.py
# so inference prompts match the training distribution.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""
# Variant of INSTRUCTION_PREFIX used when the game state is rendered as a
# natural-language description (input_mode == "nl") instead of raw JSON.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build a prompt from a flattened game_history-style example.

    Mirrors _build_prompt in train_battleground_rlaif_gamehistory.py so that
    the inference-time prompt distribution matches training.

    The example is expected to carry:
      - phase: string (e.g., "PlayerTurn")
      - turn: int
      - state: nested dict with keys: game_state, player_hero, resources, board_state
    """
    state = example.get("state", {}) or {}
    if input_mode == "nl":
        # Natural-language mode: render the state via the project NL helper.
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)
    # JSON mode: wrap the state in a small task envelope. phase/turn fall back
    # to values inside state["game_state"] when missing at the top level.
    game_state = state.get("game_state", {}) or {}
    envelope = {
        "task": "battlegrounds_policy_v1",
        "phase": example.get("phase", game_state.get("phase", "PlayerTurn")),
        "turn": example.get("turn", game_state.get("turn_number", 0)),
        "state": state,
    }
    serialized = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + serialized
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a list of atomic action dicts.

    Expected formats (same as the training reward parser):
      - {"actions": [ {...}, {...}, ... ]}
      - {"action":  [ {...}, {...}, ... ]}   # tolerated fallback

    Returns None when no JSON object can be extracted, the payload is not a
    dict, neither key yields a list/dict, or any step is not a dict.
    """
    stripped = text.strip()
    # Take the widest brace-delimited span; tolerates leading/trailing chatter.
    first = stripped.find("{")
    last = stripped.rfind("}")
    if first == -1 or last == -1:
        return None
    try:
        payload = json.loads(stripped[first : last + 1])
    except Exception:
        return None
    if not isinstance(payload, dict):
        return None
    # "actions" wins when present (even if malformed); "action" is only
    # consulted when "actions" is absent.
    raw = None
    for key in ("actions", "action"):
        if key in payload:
            value = payload[key]
            if isinstance(value, list):
                raw = value
            elif isinstance(value, dict):
                raw = [value]
            break
    if raw is None:
        return None
    if any(not isinstance(step, dict) for step in raw):
        return None
    return list(raw)
def run_inference_via_client(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Run inference over a list of examples and return enriched records.

    Each output row is the original example plus:
      - raw_completion: raw text returned by the model
      - actions: parsed list of atomic action dicts (or None on parse failure)
    """
    enriched: List[Dict[str, Any]] = []
    for example in examples:
        completion = client.text_generation(
            build_prompt(example, input_mode=input_mode),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        row = dict(example)
        row["raw_completion"] = completion
        row["actions"] = parse_actions_from_completion(completion)
        enriched.append(row)
    return enriched
def run_inference_via_space(
    space_url: str,
    examples: List[Dict[str, Any]],
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    timeout: int = 120,
    hf_token: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Call the deployed Space /generate_actions endpoint for each example."""
    endpoint = space_url.rstrip("/") + "/generate_actions"
    headers = {"Content-Type": "application/json"}
    if hf_token:
        # Needed for private Spaces.
        headers["Authorization"] = f"Bearer {hf_token}"
    enriched: List[Dict[str, Any]] = []
    for example in examples:
        body = {
            "phase": example.get("phase"),
            "turn": example.get("turn"),
            "state": example.get("state", {}),
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }
        # One request per example; any HTTP error aborts the whole run.
        response = requests.post(endpoint, json=body, headers=headers, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
        row = dict(example)
        row["actions"] = payload.get("actions")
        row["raw_completion"] = payload.get("raw_completion")
        enriched.append(row)
    return enriched
def load_examples(path: str) -> List[Dict[str, Any]]:
    """Load a JSON file that must contain a list of flattened example rows.

    Raises FileNotFoundError when the path does not exist and ValueError when
    the top-level JSON value is not a list.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(path)
    data = json.loads(source.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Expected input JSON to be a list of examples (flat rows)")
    return data
def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write rows as JSONL (one JSON object per line), creating parent dirs."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = [json.dumps(row, ensure_ascii=False) for row in rows]
    with target.open("w", encoding="utf-8") as handle:
        handle.writelines(line + "\n" for line in serialized)
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments; exits with an error when no backend is chosen."""
    ap = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )
    ap.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of flattened game_history rows).",
    )
    ap.add_argument(
        "--output",
        required=True,
        help="Path to output JSONL file with actions and raw completions.",
    )
    ap.add_argument(
        "--space-url",
        default=None,
        help=(
            "URL of the Hugging Face Space hosting /generate_actions (e.g. "
            "https://iteratehack-deepbattler.hf.space). If provided, the script calls "
            "that endpoint instead of the Inference API."
        ),
    )
    ap.add_argument(
        "--model-id",
        default=None,
        help=(
            "Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). "
            "Used only if --space-url is omitted."
        ),
    )
    ap.add_argument(
        "--endpoint",
        default=None,
        help=(
            "Full URL of a dedicated Inference Endpoint. If provided (and --space-url missing), "
            "this takes precedence over --model-id."
        ),
    )
    ap.add_argument(
        "--hf-token",
        default=None,
        help=(
            "Hugging Face access token. Needed for private Spaces/models. If omitted, use the token "
            "from `huggingface-cli login` or HF_TOKEN env var."
        ),
    )
    ap.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Match the input_mode used during training (json or nl).",
    )
    ap.add_argument("--max-new-tokens", type=int, default=256)
    ap.add_argument("--temperature", type=float, default=0.2)
    ap.add_argument(
        "--request-timeout",
        type=int,
        default=120,
        help="Timeout (seconds) for HTTP requests when using --space-url",
    )
    ap.add_argument(
        "--print-results",
        action="store_true",
        help="Print each output row (JSON) to stdout after inference.",
    )
    args = ap.parse_args()
    # At least one backend selector is mandatory.
    if not (args.space_url or args.endpoint or args.model_id):
        ap.error("Provide --space-url, --endpoint, or --model-id")
    return args
def main() -> None:
    """Entry point: load examples, run the chosen backend, write JSONL output."""
    args = parse_args()
    examples = load_examples(args.input)
    if args.space_url:
        # Space backend: POST each example to /generate_actions.
        results = run_inference_via_space(
            args.space_url,
            examples,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            timeout=args.request_timeout,
            hf_token=args.hf_token,
        )
    else:
        # Inference API backend: --endpoint takes precedence over --model-id.
        client = InferenceClient(args.endpoint or args.model_id, token=args.hf_token)
        results = run_inference_via_client(
            client,
            examples,
            input_mode=args.input_mode,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
    save_results(args.output, results)
    print(f"Wrote {len(results)} rows to {args.output}")
    if args.print_results:
        for row in results:
            print(json.dumps(row, ensure_ascii=False))


if __name__ == "__main__":
    main()