Spaces:
Runtime error
Runtime error
PPP commited on
Commit ·
da29cc6
1
Parent(s): 07f6446
fix(nlu): relabel move-to-shop purchase requests as trade
Browse files- nlu_engine.py +48 -19
nlu_engine.py
CHANGED
|
@@ -133,15 +133,17 @@ class NLUEngine:
|
|
| 133 |
result = self._llm_parse(user_input)
|
| 134 |
|
| 135 |
# 如果 LLM 解析失败,使用关键词降级
|
| 136 |
-
if result is None:
|
| 137 |
-
logger.warning("LLM 解析失败,使用关键词降级")
|
| 138 |
-
result = self._keyword_fallback(user_input)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def _llm_parse(self, user_input: str) -> Optional[dict]:
|
| 147 |
"""
|
|
@@ -239,10 +241,10 @@ class NLUEngine:
|
|
| 239 |
"parser_source": "keyword_fallback",
|
| 240 |
}
|
| 241 |
|
| 242 |
-
def _extract_target_from_text(self, text: str) -> Optional[str]:
|
| 243 |
-
"""
|
| 244 |
-
从文本中提取可能的目标对象。
|
| 245 |
-
尝试匹配当前场景中的 NPC、物品、地点名称。
|
| 246 |
"""
|
| 247 |
# 检查 NPC 名称
|
| 248 |
for npc_name in self.game_state.world.npcs:
|
|
@@ -265,12 +267,39 @@ class NLUEngine:
|
|
| 265 |
for item_name in self.game_state.world.item_registry:
|
| 266 |
if item_name in text:
|
| 267 |
return item_name
|
| 268 |
-
|
| 269 |
-
return None
|
| 270 |
-
|
| 271 |
-
def
|
| 272 |
-
"""
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
return (
|
| 275 |
f"场景: {gs.world.current_scene}\n"
|
| 276 |
f"时间: 第{gs.world.day_count}天 {gs.world.time_of_day}\n"
|
|
|
|
| 133 |
result = self._llm_parse(user_input)
|
| 134 |
|
| 135 |
# 如果 LLM 解析失败,使用关键词降级
|
| 136 |
+
if result is None:
|
| 137 |
+
logger.warning("LLM 解析失败,使用关键词降级")
|
| 138 |
+
result = self._keyword_fallback(user_input)
|
| 139 |
+
|
| 140 |
+
result = self._apply_intent_postprocessing(result, user_input)
|
| 141 |
+
|
| 142 |
+
# 附加原始输入
|
| 143 |
+
result["raw_input"] = user_input
|
| 144 |
+
|
| 145 |
+
logger.info(f"NLU 解析结果: {result}")
|
| 146 |
+
return result
|
| 147 |
|
| 148 |
def _llm_parse(self, user_input: str) -> Optional[dict]:
|
| 149 |
"""
|
|
|
|
| 241 |
"parser_source": "keyword_fallback",
|
| 242 |
}
|
| 243 |
|
| 244 |
+
def _extract_target_from_text(self, text: str) -> Optional[str]:
|
| 245 |
+
"""
|
| 246 |
+
从文本中提取可能的目标对象。
|
| 247 |
+
尝试匹配当前场景中的 NPC、物品、地点名称。
|
| 248 |
"""
|
| 249 |
# 检查 NPC 名称
|
| 250 |
for npc_name in self.game_state.world.npcs:
|
|
|
|
| 267 |
for item_name in self.game_state.world.item_registry:
|
| 268 |
if item_name in text:
|
| 269 |
return item_name
|
| 270 |
+
|
| 271 |
+
return None
|
| 272 |
+
|
| 273 |
+
def _apply_intent_postprocessing(self, result: dict, user_input: str) -> dict:
|
| 274 |
+
"""Apply narrow intent corrections for high-confidence mixed phrases."""
|
| 275 |
+
normalized = dict(result)
|
| 276 |
+
intent = str(normalized.get("intent", "")).upper()
|
| 277 |
+
if intent == "MOVE" and self._looks_like_trade_request(user_input, normalized.get("target")):
|
| 278 |
+
normalized["intent"] = "TRADE"
|
| 279 |
+
normalized["intent_correction"] = "move_to_trade_for_shop_request"
|
| 280 |
+
return normalized
|
| 281 |
+
|
| 282 |
+
def _looks_like_trade_request(self, user_input: str, target: Optional[str]) -> bool:
|
| 283 |
+
trade_pattern = r"买|卖|交易|购买|出售|看看有什么卖的|买点"
|
| 284 |
+
if not re.search(trade_pattern, user_input):
|
| 285 |
+
return False
|
| 286 |
+
|
| 287 |
+
target_text = str(target or "")
|
| 288 |
+
if target_text:
|
| 289 |
+
npc = self.game_state.world.npcs.get(target_text)
|
| 290 |
+
if npc and npc.can_trade:
|
| 291 |
+
return True
|
| 292 |
+
|
| 293 |
+
location = self.game_state.world.locations.get(target_text)
|
| 294 |
+
if location and location.shop_available:
|
| 295 |
+
return True
|
| 296 |
+
|
| 297 |
+
shop_hint_pattern = r"商店|杂货铺|旅店|铁匠铺"
|
| 298 |
+
return bool(re.search(shop_hint_pattern, user_input))
|
| 299 |
+
|
| 300 |
+
def _build_context(self) -> str:
|
| 301 |
+
"""构建当前场景的简要上下文描述"""
|
| 302 |
+
gs = self.game_state
|
| 303 |
return (
|
| 304 |
f"场景: {gs.world.current_scene}\n"
|
| 305 |
f"时间: 第{gs.world.day_count}天 {gs.world.time_of_day}\n"
|