lbtwyk committed on
Commit
f46834b
·
1 Parent(s): 2adb71d

Remove natural language input mode and simplify prompt building

Browse files
Files changed (1) hide show
  1. app.py +14 -51
app.py CHANGED
@@ -3,22 +3,10 @@ from pydantic import BaseModel
3
  from typing import Any, Dict, List, Optional
4
 
5
  import json
6
- import sys
7
- from pathlib import Path
8
-
9
  import torch
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
  from peft import PeftModel
12
 
13
- ROOT_DIR = Path(__file__).resolve().parent
14
- RL_DIR = ROOT_DIR / "RL"
15
- for path in (ROOT_DIR, RL_DIR):
16
- path_str = str(path)
17
- if path_str not in sys.path:
18
- sys.path.append(path_str)
19
-
20
- from RL.battleground_nl_utils import game_state_to_natural_language
21
-
22
 
23
  BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
24
  ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
@@ -50,55 +38,30 @@ Rules:
50
  Now here is the game state JSON:
51
  """
52
 
53
- INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
54
- Given the following natural language description of the current game state, choose
55
- the best full-turn sequence of actions and respond with a single JSON object in
56
- this exact format:
57
- {"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
58
- Rules:
59
- 1. Respond with JSON only. Do not add explanations or any extra text.
60
- 2. The top-level object must have exactly one key: "actions".
61
- 3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
62
- atomic action objects.
63
- 4. Use 0-based integers for indices or null when not used.
64
- 5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
65
- "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
66
- 6. "card_name" must exactly match a card name from the game state when required,
67
- otherwise null.
68
- Now here is the description of the game state:
69
- """
70
 
71
 
72
  class GenerateRequest(BaseModel):
73
  phase: Optional[str] = None
74
  turn: Optional[int] = None
75
  state: Dict[str, Any]
76
- input_mode: str = "json" # "json" or "nl"
77
  max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
78
  temperature: float = DEFAULT_TEMPERATURE
79
 
80
 
81
- def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
 
82
  state = example.get("state", {}) or {}
83
-
84
- if input_mode == "nl":
85
- nl_state = game_state_to_natural_language(state)
86
- prefix = INSTRUCTION_PREFIX_NL
87
- state_text = nl_state
88
- else:
89
- gs = state.get("game_state", {}) or {}
90
- phase = example.get("phase", gs.get("phase", "PlayerTurn"))
91
- turn = example.get("turn", gs.get("turn_number", 0))
92
- obj = {
93
- "task": "battlegrounds_policy_v1",
94
- "phase": phase,
95
- "turn": turn,
96
- "state": state,
97
- }
98
- state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
99
- prefix = INSTRUCTION_PREFIX
100
-
101
- return prefix + "\n" + state_text
102
 
103
 
104
  def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
@@ -199,7 +162,7 @@ def generate_actions(req: GenerateRequest):
199
  "turn": req.turn,
200
  "state": req.state,
201
  }
202
- prompt = build_prompt(example, input_mode=req.input_mode)
203
 
204
  inputs = tokenizer(prompt, return_tensors="pt")
205
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
3
  from typing import Any, Dict, List, Optional
4
 
5
  import json
 
 
 
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from peft import PeftModel
9
 
 
 
 
 
 
 
 
 
 
10
 
11
  BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
12
  ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
 
38
  Now here is the game state JSON:
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  class GenerateRequest(BaseModel):
44
  phase: Optional[str] = None
45
  turn: Optional[int] = None
46
  state: Dict[str, Any]
 
47
  max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
48
  temperature: float = DEFAULT_TEMPERATURE
49
 
50
 
51
+ def build_prompt(example: Dict[str, Any]) -> str:
52
+ """Build a JSON-mode prompt (the only mode supported by this Space)."""
53
  state = example.get("state", {}) or {}
54
+ gs = state.get("game_state", {}) or {}
55
+ phase = example.get("phase", gs.get("phase", "PlayerTurn"))
56
+ turn = example.get("turn", gs.get("turn_number", 0))
57
+ obj = {
58
+ "task": "battlegrounds_policy_v1",
59
+ "phase": phase,
60
+ "turn": turn,
61
+ "state": state,
62
+ }
63
+ state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
64
+ return INSTRUCTION_PREFIX + "\n" + state_text
 
 
 
 
 
 
 
 
65
 
66
 
67
  def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
 
162
  "turn": req.turn,
163
  "state": req.state,
164
  }
165
+ prompt = build_prompt(example)
166
 
167
  inputs = tokenizer(prompt, return_tensors="pt")
168
  inputs = {k: v.to(device) for k, v in inputs.items()}