lbtwyk commited on
Commit
e796c83
·
1 Parent(s): 8036a6e

Add HF Space GPU inference for Battlegrounds Qwen

Browse files
Files changed (3) hide show
  1. RL/infer_battleground_cloud.py +277 -0
  2. app.py +206 -1
  3. requirements.txt +2 -0
RL/infer_battleground_cloud.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # infer_battleground_cloud.py
3
+ #
4
+ # Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
5
+ #
6
+ # Usage examples:
7
+ # PYTHONPATH=. python RL/infer_battleground_cloud.py \
8
+ # --input RL/datasets/game_history_2_flat.json \
9
+ # --output RL/datasets/game_history_2_actions.jsonl \
10
+ # --model-id iteratehack/deepbattler-battleground-gamehistory
11
+ #
12
+ # or, if you deploy a dedicated Inference Endpoint:
13
+ # PYTHONPATH=. python RL/infer_battleground_cloud.py \
14
+ # --input RL/datasets/game_history_2_flat.json \
15
+ # --output RL/datasets/game_history_2_actions.jsonl \
16
+ # --endpoint https://<your-endpoint>.inference.huggingface.cloud
17
+ #
18
+ # The script expects the same "state" structure and action JSON schema as
19
+ # train_battleground_rlaif_gamehistory.py.
20
+
21
+ import argparse
22
+ import json
23
+ from pathlib import Path
24
+ from typing import Any, Dict, List, Optional
25
+
26
+ from huggingface_hub import InferenceClient
27
+
28
+ from RL.battleground_nl_utils import game_state_to_natural_language
29
+
30
+
31
# Prompt prefix for the "json" input mode, where the game state is serialized
# as a compact JSON envelope. Mirrors the prompt used by
# train_battleground_rlaif_gamehistory.py so the inference-time prompt
# distribution matches training.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""

# Prompt prefix for the "nl" input mode, where the game state is rendered as
# prose by game_state_to_natural_language instead of raw JSON. Same rules and
# output schema as INSTRUCTION_PREFIX.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
65
+
66
+
67
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build the model prompt for one flattened game_history-style example.

    Mirrors _build_prompt in train_battleground_rlaif_gamehistory.py so that
    the inference distribution matches training.

    The example should carry:
      - phase: string (e.g., "PlayerTurn")
      - turn: int
      - state: nested dict with keys: game_state, player_hero, resources, board_state
    """
    state = example.get("state", {}) or {}

    if input_mode == "nl":
        # Natural-language mode: render the state as prose.
        prefix = INSTRUCTION_PREFIX_NL
        rendered = game_state_to_natural_language(state)
    else:
        # JSON mode: wrap the state in the same envelope used at training time.
        game_state = state.get("game_state", {}) or {}
        envelope = {
            "task": "battlegrounds_policy_v1",
            "phase": example.get("phase", game_state.get("phase", "PlayerTurn")),
            "turn": example.get("turn", game_state.get("turn_number", 0)),
            "state": state,
        }
        prefix = INSTRUCTION_PREFIX
        rendered = json.dumps(envelope, separators=(",", ":"), ensure_ascii=False)

    return prefix + "\n" + rendered
99
+
100
+
101
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a list of atomic action dicts.

    Accepted shapes (same as the training reward parser):
      - {"actions": [ {...}, {...}, ... ]}
      - {"action": [ {...}, {...}, ... ]}   # tolerated fallback key
    A bare dict value is tolerated and wrapped in a one-element list.
    Returns None on any parse failure.
    """
    stripped = text.strip()
    first = stripped.find("{")
    last = stripped.rfind("}")
    if first == -1 or last == -1:
        return None

    try:
        payload = json.loads(stripped[first : last + 1])
    except Exception:
        return None
    if not isinstance(payload, dict):
        return None

    # "actions" wins over "action"; once the key exists, its value decides.
    raw = None
    for key in ("actions", "action"):
        if key in payload:
            value = payload[key]
            if isinstance(value, list):
                raw = value
            elif isinstance(value, dict):
                raw = [value]
            break

    if raw is None:
        return None
    if not all(isinstance(step, dict) for step in raw):
        return None
    return list(raw)
147
+
148
+
149
def run_inference(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Run inference over a list of examples and return enriched records.

    Each output row is a shallow copy of the original example plus:
      - raw_completion: raw text returned by the model
      - actions: parsed list of atomic action dicts (or None on parse failure)
    """
    enriched: List[Dict[str, Any]] = []
    for example in examples:
        completion = client.text_generation(
            build_prompt(example, input_mode=input_mode),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        row = dict(example)
        row["raw_completion"] = completion
        row["actions"] = parse_actions_from_completion(completion)
        enriched.append(row)

    return enriched
179
+
180
+
181
def load_examples(path: str) -> List[Dict[str, Any]]:
    """Load a JSON file that must contain a list of flattened example rows.

    Raises FileNotFoundError if *path* does not exist and ValueError if the
    top-level JSON value is not a list.
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(path)

    data = json.loads(source.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("Expected input JSON to be a list of examples (flat rows)")
    return data
192
+
193
+
194
def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write *rows* as JSON Lines, creating parent directories as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = [json.dumps(row, ensure_ascii=False) for row in rows]
    target.write_text("".join(line + "\n" for line in serialized), encoding="utf-8")
200
+
201
+
202
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments; at least one of --model-id / --endpoint is required."""
    arg_parser = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )
    arg_parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of flattened game_history rows).",
    )
    arg_parser.add_argument(
        "--output",
        required=True,
        help="Path to output JSONL file with actions and raw completions.",
    )
    arg_parser.add_argument(
        "--model-id",
        default=None,
        help="Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). If provided, serverless / hosted inference will be used.",
    )
    arg_parser.add_argument(
        "--endpoint",
        default=None,
        help="Full URL of a dedicated Inference Endpoint. If provided, this takes precedence over --model-id.",
    )
    arg_parser.add_argument(
        "--hf-token",
        default=None,
        help="Hugging Face access token. If omitted, the token from `huggingface-cli login` or HF_TOKEN env var will be used.",
    )
    arg_parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Match the input_mode used during training (json or nl).",
    )
    arg_parser.add_argument("--max-new-tokens", type=int, default=256)
    arg_parser.add_argument("--temperature", type=float, default=0.2)

    parsed = arg_parser.parse_args()
    # Without a repo id or endpoint there is nothing to call.
    if not parsed.model_id and not parsed.endpoint:
        arg_parser.error("You must provide either --model-id or --endpoint")

    return parsed
254
+
255
+
256
def main() -> None:
    """CLI entry point: build a client, run inference, write JSONL results."""
    args = parse_args()

    # --endpoint takes precedence; parse_args guarantees at least one is set.
    target = args.endpoint if args.endpoint else args.model_id
    client = InferenceClient(target, token=args.hf_token)

    rows = run_inference(
        client,
        load_examples(args.input),
        input_mode=args.input_mode,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
    )
    save_results(args.output, rows)
    print(f"Wrote {len(rows)} rows to {args.output}")


if __name__ == "__main__":
    main()
app.py CHANGED
@@ -1,7 +1,212 @@
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def root():
7
- return {"status": "ok", "message": "DeepBattler Space is running"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import json
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from peft import PeftModel
9
+
10
+ from RL.battleground_nl_utils import game_state_to_natural_language
11
+
12
+
13
# Base model the LoRA adapter was trained on, and the adapter repo to load.
BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
# Generation defaults; each request may override them via GenerateRequest.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2


app = FastAPI()

# Lazily populated by load_model() (called on startup and per request as a
# no-op guard); None until the model has been loaded.
tokenizer: Optional[AutoTokenizer] = None
model = None
# Target device for inputs; on GPU the model itself is placed by
# device_map="auto" in load_model().
device = "cuda" if torch.cuda.is_available() else "cpu"


# Prompt prefix for the "json" input mode. Kept in sync with
# RL/infer_battleground_cloud.py and the training script.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""

# Prompt prefix for the "nl" (natural language) input mode.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
60
+
61
+
62
class GenerateRequest(BaseModel):
    """Request payload for the /generate_actions endpoint."""

    phase: Optional[str] = None  # e.g. "PlayerTurn"; falls back to state.game_state.phase
    turn: Optional[int] = None  # falls back to state.game_state.turn_number
    state: Dict[str, Any]  # nested game state (game_state, player_hero, resources, board_state)
    input_mode: str = "json"  # "json" or "nl"
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    temperature: float = DEFAULT_TEMPERATURE
69
+
70
+
71
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Render the prompt for one example, matching the training-time format."""
    state = example.get("state", {}) or {}

    if input_mode == "nl":
        # Natural-language rendering of the state.
        return INSTRUCTION_PREFIX_NL + "\n" + game_state_to_natural_language(state)

    gs = state.get("game_state", {}) or {}
    payload = json.dumps(
        {
            "task": "battlegrounds_policy_v1",
            "phase": example.get("phase", gs.get("phase", "PlayerTurn")),
            "turn": example.get("turn", gs.get("turn_number", 0)),
            "state": state,
        },
        separators=(",", ":"),
        ensure_ascii=False,
    )
    return INSTRUCTION_PREFIX + "\n" + payload
92
+
93
+
94
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Extract the action list from a model completion, or None on failure.

    Accepts {"actions": [...]} or the tolerated fallback {"action": [...]};
    a bare dict value is wrapped in a one-element list.
    """
    body = text.strip()
    lo, hi = body.find("{"), body.rfind("}")
    if lo == -1 or hi == -1:
        return None

    try:
        parsed = json.loads(body[lo : hi + 1])
    except Exception:
        return None
    if not isinstance(parsed, dict):
        return None

    if "actions" in parsed:
        candidate = parsed["actions"]
    elif "action" in parsed:  # tolerated fallback key
        candidate = parsed["action"]
    else:
        return None

    if isinstance(candidate, dict):
        candidate = [candidate]
    if not isinstance(candidate, list):
        return None

    steps: List[Dict[str, Any]] = []
    for item in candidate:
        if not isinstance(item, dict):
            return None
        steps.append(item)
    return steps
133
+
134
+
135
def load_model() -> None:
    """Load tokenizer, base model, and LoRA adapter into module globals.

    Idempotent: returns immediately once both ``tokenizer`` and ``model`` are
    populated, so it is safe to call both at startup and per request.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    # Tokenizer comes from the adapter repo so any tokens added during
    # fine-tuning match the adapter's expectations.
    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token  # reuse EOS as pad when none is defined
    tok.padding_side = "left"

    if torch.cuda.is_available():
        # GPU path: accelerate places the weights (device_map="auto"); bf16
        # halves memory versus fp32.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        # CPU fallback: full precision, single device.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    # Attach the LoRA adapter on top of the base weights.
    peft_model = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not torch.cuda.is_available():
        # On GPU device_map="auto" already placed the model; only move on CPU.
        peft_model.to(device)
    peft_model.eval()

    tokenizer = tok
    model = peft_model
166
+
167
+
168
@app.on_event("startup")
async def _startup_event() -> None:
    # Warm the model at process start so the first request doesn't pay the
    # multi-GB load time.
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions;
    # consider migrating to the lifespan context manager when upgrading.
    load_model()
172
+
173
  @app.get("/")
174
  def root():
175
+ return {
176
+ "status": "ok",
177
+ "message": "DeepBattler Battlegrounds Space is running",
178
+ "base_model": BASE_MODEL_ID,
179
+ "adapter_model": ADAPTER_MODEL_ID,
180
+ }
181
+
182
+
183
+ @app.post("/generate_actions")
184
+ def generate_actions(req: GenerateRequest):
185
+ load_model()
186
+
187
+ example = {
188
+ "phase": req.phase,
189
+ "turn": req.turn,
190
+ "state": req.state,
191
+ }
192
+ prompt = build_prompt(example, input_mode=req.input_mode)
193
+
194
+ inputs = tokenizer(prompt, return_tensors="pt")
195
+ inputs = {k: v.to(device) for k, v in inputs.items()}
196
+
197
+ with torch.no_grad():
198
+ output_ids = model.generate(
199
+ **inputs,
200
+ max_new_tokens=req.max_new_tokens,
201
+ do_sample=True,
202
+ temperature=req.temperature,
203
+ )
204
+
205
+ generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
206
+ completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
207
+ actions = parse_actions_from_completion(completion)
208
+
209
+ return {
210
+ "actions": actions,
211
+ "raw_completion": completion,
212
+ }
requirements.txt CHANGED
@@ -8,3 +8,5 @@ modelscope>=1.15.0
8
  datasets>=2.19.0
9
  peft>=0.11.1
10
  trl>=0.25.1
 
 
 
8
  datasets>=2.19.0
9
  peft>=0.11.1
10
  trl>=0.25.1
11
+ fastapi>=0.115.0
12
+ uvicorn[standard]>=0.30.0