agarwalanu3103 commited on
Commit
b8a5922
·
verified ·
1 Parent(s): f251890

Parser: support ASK:/PROPOSE:/Q:/PLAN: prefix forms produced by Qwen3 GRPO

Browse files
Files changed (1) hide show
  1. inference.py +65 -0
inference.py CHANGED
@@ -112,6 +112,20 @@ def create_client() -> Optional[OpenAI]:
112
  return None
113
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
116
  cleaned = _strip_reasoning(response_text)
117
  tool_match = re.search(r"TOOL:\s*(\w+)", cleaned, re.IGNORECASE)
@@ -135,6 +149,10 @@ def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
135
  args = _parse_positional_args(tool_name, raw_body.strip())
136
  return tool_name, args
137
 
 
 
 
 
138
  action_match = re.search(
139
  r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
140
  cleaned, re.DOTALL,
@@ -155,6 +173,53 @@ def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
155
  return None, {}
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def _strip_reasoning(response_text: str) -> str:
159
  cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE)
160
  cleaned = cleaned.replace("```json", "```")
 
112
  return None
113
 
114
 
115
+ _PREFIX_TO_TOOL = {
116
+ "ASK": "ask_question",
117
+ "ASK_QUESTION": "ask_question",
118
+ "QUESTION": "ask_question",
119
+ "Q": "ask_question",
120
+ "PROPOSE": "propose_plan",
121
+ "PROPOSE_PLAN": "propose_plan",
122
+ "PLAN": "propose_plan",
123
+ "INFO": "get_task_info",
124
+ "GET_TASK_INFO": "get_task_info",
125
+ "TASK_INFO": "get_task_info",
126
+ }
127
+
128
+
129
  def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
130
  cleaned = _strip_reasoning(response_text)
131
  tool_match = re.search(r"TOOL:\s*(\w+)", cleaned, re.IGNORECASE)
 
149
  args = _parse_positional_args(tool_name, raw_body.strip())
150
  return tool_name, args
151
 
152
+ prefix_tool, prefix_args = _parse_prefixed_call(cleaned)
153
+ if prefix_tool:
154
+ return prefix_tool, prefix_args
155
+
156
  action_match = re.search(
157
  r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
158
  cleaned, re.DOTALL,
 
173
  return None, {}
174
 
175
 
176
+ def _parse_prefixed_call(text: str) -> tuple[Optional[str], dict]:
177
+ """Handle Qwen3 GRPO outputs like:
178
+ ASK: {"question": "What is the budget?"}
179
+ ASK: What is the budget?
180
+ PROPOSE: {"date": "2024-12-25", ...}
181
+ Q: What is the budget?
182
+
183
+ The 0.6B GRPO checkpoint emits these ~20% of the time. We map the
184
+ prefix to the canonical tool name and parse the rest as either a JSON
185
+ object or a free-form question/plan string.
186
+ """
187
+ match = re.match(r"^\s*([A-Za-z_]+)\s*:\s*(.*)$", text, flags=re.DOTALL)
188
+ if not match:
189
+ return None, {}
190
+ prefix = match.group(1).upper().replace("-", "_")
191
+ if prefix not in _PREFIX_TO_TOOL:
192
+ return None, {}
193
+ tool_name = _PREFIX_TO_TOOL[prefix]
194
+ rest = match.group(2).strip()
195
+
196
+ if rest.startswith("{"):
197
+ parsed = _load_json_like(rest)
198
+ if isinstance(parsed, dict) and parsed:
199
+ if tool_name == "ask_question":
200
+ question = parsed.get("question") or parsed.get("q") or parsed.get("text")
201
+ if isinstance(question, str):
202
+ return tool_name, {"question": question}
203
+ return tool_name, {"question": json.dumps(parsed)}
204
+ if tool_name == "propose_plan":
205
+ inner = parsed.get("plan") if isinstance(parsed.get("plan"), (dict, str)) else None
206
+ if inner is not None:
207
+ plan_str = inner if isinstance(inner, str) else json.dumps(inner)
208
+ return tool_name, {"plan": plan_str}
209
+ return tool_name, {"plan": json.dumps(parsed)}
210
+ return tool_name, {}
211
+
212
+ if tool_name == "ask_question":
213
+ question = rest.strip().strip('"').strip("'")
214
+ if question:
215
+ return tool_name, {"question": question}
216
+ if tool_name == "propose_plan" and rest:
217
+ return tool_name, {"plan": rest}
218
+ if tool_name == "get_task_info":
219
+ return tool_name, {}
220
+ return None, {}
221
+
222
+
223
  def _strip_reasoning(response_text: str) -> str:
224
  cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE)
225
  cleaned = cleaned.replace("```json", "```")