File size: 11,832 Bytes
e796c83
 
 
 
 
fed1ca7
 
 
 
e796c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed1ca7
e796c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed1ca7
e796c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed1ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e796c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed1ca7
 
 
 
 
 
 
 
 
e796c83
 
 
 
 
fed1ca7
e796c83
 
 
 
 
 
fed1ca7
 
e796c83
 
 
 
 
 
fed1ca7
 
e796c83
 
 
 
 
 
 
 
 
 
fed1ca7
 
 
 
 
 
 
 
 
 
 
e796c83
 
fed1ca7
 
e796c83
 
 
 
 
 
 
 
fed1ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e796c83
 
 
fed1ca7
 
 
 
e796c83
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#!/usr/bin/env python
# infer_battleground_cloud.py
#
# Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
#
# Backends supported:
#   1. Hugging Face Space exposing /generate_actions (preferred for this project)
#   2. Hugging Face Inference Endpoint / Hosted model via InferenceClient
#
# Usage examples:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --model-id iteratehack/deepbattler-battleground-gamehistory
#
# or, if you deploy a dedicated Inference Endpoint:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --endpoint https://<your-endpoint>.inference.huggingface.cloud
#
# The script expects the same "state" structure and action JSON schema as
# train_battleground_rlaif_gamehistory.py.

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from huggingface_hub import InferenceClient

from RL.battleground_nl_utils import game_state_to_natural_language


# Prompt preamble for "json" input mode: build_prompt appends the serialized
# state envelope directly after this text. The wording mirrors the training
# script (train_battleground_rlaif_gamehistory.py); do not edit the string
# without retraining, since the model was fine-tuned on this exact prompt.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""

# Prompt preamble for "nl" input mode: build_prompt appends the natural-language
# state description (game_state_to_natural_language) after this text. Same
# caveat as above — the string must stay byte-identical to training.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""


def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Construct the model prompt for one flattened game_history row.

    Mirrors _build_prompt in train_battleground_rlaif_gamehistory.py so the
    inference-time prompt distribution matches what the model saw in training.

    Expected example keys:
      - phase: string (e.g., "PlayerTurn")
      - turn: int
      - state: nested dict with keys: game_state, player_hero, resources, board_state
    """

    state = example.get("state", {}) or {}

    if input_mode == "nl":
        # Natural-language mode: render the state as prose after the NL preamble.
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)

    # Default "json" mode: wrap the state in the same envelope used at training
    # time, preferring top-level phase/turn over the nested game_state values.
    game_state = state.get("game_state", {}) or {}
    envelope = {
        "task": "battlegrounds_policy_v1",
        "phase": example.get("phase", game_state.get("phase", "PlayerTurn")),
        "turn": example.get("turn", game_state.get("turn_number", 0)),
        "state": state,
    }
    serialized = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + serialized


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a raw model completion.

    Accepts the same shapes the training reward parser does:
      - {"actions": [ {...}, {...}, ... ]}
      - {"action": [ {...}, {...}, ... ]}  # tolerated fallback
    A bare dict under either key is wrapped into a one-element list.

    Returns None whenever the text does not contain a parseable JSON object
    whose action entries are all dicts.
    """

    stripped = text.strip()
    first_brace = stripped.find("{")
    last_brace = stripped.rfind("}")
    # Without both braces there is no candidate JSON object to parse.
    if first_brace == -1 or last_brace == -1:
        return None

    try:
        obj = json.loads(stripped[first_brace : last_brace + 1])
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None

    seq = None
    # "actions" is the canonical key; "action" is only consulted when the
    # canonical key is absent (matching the original elif behavior).
    for key in ("actions", "action"):
        if key in obj:
            value = obj[key]
            if isinstance(value, list):
                seq = value
            elif isinstance(value, dict):
                seq = [value]
            break

    if seq is None:
        return None
    # Reject the whole sequence if any step is not a dict.
    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)


def run_inference_via_client(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Generate and parse actions for each example via an InferenceClient.

    Each output row is a shallow copy of the input example with two extra keys:
      - raw_completion: raw text returned by the model
      - actions: parsed list of atomic action dicts (None on parse failure)
    """

    enriched: List[Dict[str, Any]] = []
    for example in examples:
        completion = client.text_generation(
            build_prompt(example, input_mode=input_mode),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        row = dict(example)
        row["raw_completion"] = completion
        row["actions"] = parse_actions_from_completion(completion)
        enriched.append(row)

    return enriched


def run_inference_via_space(
    space_url: str,
    examples: List[Dict[str, Any]],
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    timeout: int = 120,
    hf_token: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """POST each example to the Space's /generate_actions endpoint.

    Returns a copy of each example enriched with the "actions" and
    "raw_completion" fields from the Space's JSON response. Any non-2xx
    response raises via requests' raise_for_status.
    """

    url = f"{space_url.rstrip('/')}/generate_actions"
    headers = {"Content-Type": "application/json"}
    if hf_token:
        # Bearer auth is only required for private Spaces.
        headers["Authorization"] = f"Bearer {hf_token}"

    rows: List[Dict[str, Any]] = []
    for example in examples:
        body = {
            "phase": example.get("phase"),
            "turn": example.get("turn"),
            "state": example.get("state", {}),
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }
        response = requests.post(url, json=body, headers=headers, timeout=timeout)
        response.raise_for_status()
        reply = response.json()

        row = dict(example)
        row["actions"] = reply.get("actions")
        row["raw_completion"] = reply.get("raw_completion")
        rows.append(row)

    return rows


def load_examples(path: str) -> List[Dict[str, Any]]:
    """Load the flat list of example rows from a JSON file.

    Raises FileNotFoundError when the file does not exist and ValueError when
    the top-level JSON value is not a list.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(path)

    data = json.loads(source.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Expected input JSON to be a list of examples (flat rows)")
    return data


def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write rows as JSON Lines (one compact JSON object per line).

    Parent directories are created as needed; an existing file is overwritten.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments.

    Requires --input and --output, plus at least one backend selector among
    --space-url, --endpoint, and --model-id (checked after parsing).
    """
    parser = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )

    # Required I/O paths.
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of flattened game_history rows).",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Path to output JSONL file with actions and raw completions.",
    )

    # Backend selection: exactly one of the three is needed (validated below).
    parser.add_argument(
        "--space-url",
        default=None,
        help="URL of the Hugging Face Space hosting /generate_actions (e.g. "
        "https://iteratehack-deepbattler.hf.space). If provided, the script calls "
        "that endpoint instead of the Inference API.",
    )
    parser.add_argument(
        "--model-id",
        default=None,
        help="Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). "
        "Used only if --space-url is omitted.",
    )
    parser.add_argument(
        "--endpoint",
        default=None,
        help="Full URL of a dedicated Inference Endpoint. If provided (and --space-url missing), "
        "this takes precedence over --model-id.",
    )
    parser.add_argument(
        "--hf-token",
        default=None,
        help="Hugging Face access token. Needed for private Spaces/models. If omitted, use the token "
        "from `huggingface-cli login` or HF_TOKEN env var.",
    )

    # Generation and transport knobs.
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Match the input_mode used during training (json or nl).",
    )
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument(
        "--request-timeout",
        type=int,
        default=120,
        help="Timeout (seconds) for HTTP requests when using --space-url",
    )
    parser.add_argument(
        "--print-results",
        action="store_true",
        help="Print each output row (JSON) to stdout after inference.",
    )

    args = parser.parse_args()
    # At least one backend must be selected, otherwise there is nothing to call.
    if not (args.space_url or args.endpoint or args.model_id):
        parser.error("Provide --space-url, --endpoint, or --model-id")

    return args


def main() -> None:
    """Entry point: load examples, run the selected backend, write JSONL output."""
    args = parse_args()
    examples = load_examples(args.input)

    if args.space_url:
        # Preferred backend for this project: the deployed Space endpoint.
        results = run_inference_via_space(
            args.space_url,
            examples,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            timeout=args.request_timeout,
            hf_token=args.hf_token,
        )
    else:
        # Inference API fallback: a dedicated endpoint URL wins over a repo id.
        target = args.endpoint if args.endpoint else args.model_id
        client = InferenceClient(target, token=args.hf_token)
        results = run_inference_via_client(
            client,
            examples,
            input_mode=args.input_mode,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )

    save_results(args.output, results)
    print(f"Wrote {len(results)} rows to {args.output}")

    if args.print_results:
        for row in results:
            print(json.dumps(row, ensure_ascii=False))


if __name__ == "__main__":
    main()