Spaces:
Sleeping
Sleeping
Parser: support ASK:/PROPOSE:/Q:/PLAN: prefix forms produced by Qwen3 GRPO
Browse files- inference.py +65 -0
inference.py
CHANGED
|
@@ -112,6 +112,20 @@ def create_client() -> Optional[OpenAI]:
|
|
| 112 |
return None
|
| 113 |
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
|
| 116 |
cleaned = _strip_reasoning(response_text)
|
| 117 |
tool_match = re.search(r"TOOL:\s*(\w+)", cleaned, re.IGNORECASE)
|
|
@@ -135,6 +149,10 @@ def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
|
|
| 135 |
args = _parse_positional_args(tool_name, raw_body.strip())
|
| 136 |
return tool_name, args
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
action_match = re.search(
|
| 139 |
r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
|
| 140 |
cleaned, re.DOTALL,
|
|
@@ -155,6 +173,53 @@ def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
|
|
| 155 |
return None, {}
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
def _strip_reasoning(response_text: str) -> str:
|
| 159 |
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE)
|
| 160 |
cleaned = cleaned.replace("```json", "```")
|
|
|
|
| 112 |
return None
|
| 113 |
|
| 114 |
|
| 115 |
+
_PREFIX_TO_TOOL = {
|
| 116 |
+
"ASK": "ask_question",
|
| 117 |
+
"ASK_QUESTION": "ask_question",
|
| 118 |
+
"QUESTION": "ask_question",
|
| 119 |
+
"Q": "ask_question",
|
| 120 |
+
"PROPOSE": "propose_plan",
|
| 121 |
+
"PROPOSE_PLAN": "propose_plan",
|
| 122 |
+
"PLAN": "propose_plan",
|
| 123 |
+
"INFO": "get_task_info",
|
| 124 |
+
"GET_TASK_INFO": "get_task_info",
|
| 125 |
+
"TASK_INFO": "get_task_info",
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
def parse_tool_call(response_text: str) -> tuple[Optional[str], dict]:
|
| 130 |
cleaned = _strip_reasoning(response_text)
|
| 131 |
tool_match = re.search(r"TOOL:\s*(\w+)", cleaned, re.IGNORECASE)
|
|
|
|
| 149 |
args = _parse_positional_args(tool_name, raw_body.strip())
|
| 150 |
return tool_name, args
|
| 151 |
|
| 152 |
+
prefix_tool, prefix_args = _parse_prefixed_call(cleaned)
|
| 153 |
+
if prefix_tool:
|
| 154 |
+
return prefix_tool, prefix_args
|
| 155 |
+
|
| 156 |
action_match = re.search(
|
| 157 |
r'Action:\s*(\w+)\((?:(\w+)\s*=\s*["\'](.+?)["\']|([^)]*))\)',
|
| 158 |
cleaned, re.DOTALL,
|
|
|
|
| 173 |
return None, {}
|
| 174 |
|
| 175 |
|
| 176 |
+
def _parse_prefixed_call(text: str) -> tuple[Optional[str], dict]:
|
| 177 |
+
"""Handle Qwen3 GRPO outputs like:
|
| 178 |
+
ASK: {"question": "What is the budget?"}
|
| 179 |
+
ASK: What is the budget?
|
| 180 |
+
PROPOSE: {"date": "2024-12-25", ...}
|
| 181 |
+
Q: What is the budget?
|
| 182 |
+
|
| 183 |
+
The 0.6B GRPO checkpoint emits these ~20% of the time. We map the
|
| 184 |
+
prefix to the canonical tool name and parse the rest as either a JSON
|
| 185 |
+
object or a free-form question/plan string.
|
| 186 |
+
"""
|
| 187 |
+
match = re.match(r"^\s*([A-Za-z_]+)\s*:\s*(.*)$", text, flags=re.DOTALL)
|
| 188 |
+
if not match:
|
| 189 |
+
return None, {}
|
| 190 |
+
prefix = match.group(1).upper().replace("-", "_")
|
| 191 |
+
if prefix not in _PREFIX_TO_TOOL:
|
| 192 |
+
return None, {}
|
| 193 |
+
tool_name = _PREFIX_TO_TOOL[prefix]
|
| 194 |
+
rest = match.group(2).strip()
|
| 195 |
+
|
| 196 |
+
if rest.startswith("{"):
|
| 197 |
+
parsed = _load_json_like(rest)
|
| 198 |
+
if isinstance(parsed, dict) and parsed:
|
| 199 |
+
if tool_name == "ask_question":
|
| 200 |
+
question = parsed.get("question") or parsed.get("q") or parsed.get("text")
|
| 201 |
+
if isinstance(question, str):
|
| 202 |
+
return tool_name, {"question": question}
|
| 203 |
+
return tool_name, {"question": json.dumps(parsed)}
|
| 204 |
+
if tool_name == "propose_plan":
|
| 205 |
+
inner = parsed.get("plan") if isinstance(parsed.get("plan"), (dict, str)) else None
|
| 206 |
+
if inner is not None:
|
| 207 |
+
plan_str = inner if isinstance(inner, str) else json.dumps(inner)
|
| 208 |
+
return tool_name, {"plan": plan_str}
|
| 209 |
+
return tool_name, {"plan": json.dumps(parsed)}
|
| 210 |
+
return tool_name, {}
|
| 211 |
+
|
| 212 |
+
if tool_name == "ask_question":
|
| 213 |
+
question = rest.strip().strip('"').strip("'")
|
| 214 |
+
if question:
|
| 215 |
+
return tool_name, {"question": question}
|
| 216 |
+
if tool_name == "propose_plan" and rest:
|
| 217 |
+
return tool_name, {"plan": rest}
|
| 218 |
+
if tool_name == "get_task_info":
|
| 219 |
+
return tool_name, {}
|
| 220 |
+
return None, {}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
def _strip_reasoning(response_text: str) -> str:
|
| 224 |
cleaned = re.sub(r"<think>.*?</think>", "", response_text, flags=re.DOTALL | re.IGNORECASE)
|
| 225 |
cleaned = cleaned.replace("```json", "```")
|