srilakshu012456 commited on
Commit
e5bc8b1
·
verified ·
1 Parent(s): 072cf7c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +218 -25
main.py CHANGED
@@ -218,26 +218,229 @@ ERROR_FAMILY_SYNS = {
218
  ),
219
  }
220
 
221
-
222
  def _try_wms_tool(user_text: str) -> dict | None:
223
  """
224
- Uses Gemini function calling + TOOL_REGISTRY to decide if a WMS tool should be called.
225
- Returns YOUR existing response dict shape. If no tool match or any error -> returns None.
 
 
226
  """
227
- # Guards
228
- if not GEMINI_API_KEY:
229
- return None
230
  if not _ensure_wms_session():
231
  return None
232
 
233
- # Prepare default warehouse from cached session
234
  session_data = _WMS_SESSION_DATA
235
  default_wh = session_data.get("user_warehouse_id") if session_data else None
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  payload = {
238
  "contents": [{"role": "user", "parts": [{"text": user_text}]}],
239
- "tools": [{"functionDeclarations": ALL_TOOLS}], # from public_api.field_mapping
240
- "toolConfig": {"functionCallingConfig": {"mode": "ANY"}}, # nudge function calling
241
  }
242
 
243
  try:
@@ -263,7 +466,6 @@ def _try_wms_tool(user_text: str) -> dict | None:
263
  if not cfg:
264
  return None
265
 
266
- # Resolve warehouse override, else use session default
267
  wh = (
268
  args.pop("warehouse_id", None)
269
  or args.pop("warehouseId", None)
@@ -271,7 +473,6 @@ def _try_wms_tool(user_text: str) -> dict | None:
271
  or default_wh
272
  )
273
 
274
- # Get callable
275
  tool_fn = cfg.get("function")
276
  if not callable(tool_fn):
277
  return {
@@ -287,12 +488,9 @@ def _try_wms_tool(user_text: str) -> dict | None:
287
  "source": "ERROR",
288
  }
289
 
290
- # ---- Correct invocation according to your real function signatures ----
291
- # get_inventory_holds(session_data, warehouse_id, **kwargs)
292
- # get_location_status(session_data, warehouse_id, stoloc)
293
- # get_item(session_data, warehouse_id, item_number, **kwargs)
294
  if tool_name == "get_inventory_holds":
295
- call_kwargs = {**args} # lodnum / subnum / dtlnum
296
  raw = tool_fn(session_data, wh, **call_kwargs)
297
 
298
  elif tool_name == "get_location_status":
@@ -318,7 +516,7 @@ def _try_wms_tool(user_text: str) -> dict | None:
318
  extra_kwargs = {k: v for k, v in args.items() if k != id_param}
319
  raw = tool_fn(session_data, wh, target, **extra_kwargs)
320
 
321
- # ---- Tool-layer error handling ----
322
  if isinstance(raw, dict) and raw.get("error"):
323
  return {
324
  "bot_response": f"⚠️ {raw['error']}",
@@ -333,7 +531,6 @@ def _try_wms_tool(user_text: str) -> dict | None:
333
  "source": "ERROR",
334
  }
335
 
336
- # ---- Extract rows from response ----
337
  response_key = cfg.get("response_key", "data")
338
  data_list = (raw.get(response_key, []) if isinstance(raw, dict) else []) or []
339
  if not data_list:
@@ -350,7 +547,6 @@ def _try_wms_tool(user_text: str) -> dict | None:
350
  "source": "LIVE_API",
351
  }
352
 
353
- # ---- Table rendering (holds + location) ----
354
  if tool_name in ("get_inventory_holds", "get_location_status"):
355
  rows_for_text, total_qty = [], 0
356
  for item in data_list:
@@ -360,10 +556,8 @@ def _try_wms_tool(user_text: str) -> dict | None:
360
 
361
  if tool_name == "get_inventory_holds":
362
  def _to_int(x):
363
- try:
364
- return int(x or 0)
365
- except Exception:
366
- return 0
367
  total_qty = sum(_to_int(it.get("untqty", 0)) for it in data_list)
368
  msg = f"Found {len(rows_for_text)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows_for_text[:10])
369
  else:
@@ -386,7 +580,7 @@ def _try_wms_tool(user_text: str) -> dict | None:
386
  "show_export": True if tool_name == "get_inventory_holds" else False,
387
  }
388
 
389
- # ---- Item summary + one-row table ----
390
  flat = flatten_json(data_list[0])
391
  requested = args.get("fields", [])
392
  if requested:
@@ -434,7 +628,6 @@ def _try_wms_tool(user_text: str) -> dict | None:
434
  print("[WMS] tool call error:", e)
435
  return None
436
 
437
-
438
  def _detect_error_families(msg: str) -> list:
439
  low = (msg or "").lower()
440
  low_norm = re.sub(r"[^\w\s]", " ", low)
 
218
  ),
219
  }
220
 
 
221
  def _try_wms_tool(user_text: str) -> dict | None:
222
  """
223
+ WMS tool orchestrator:
224
+ 1) Fast-path: regex-detects item/location/holds queries and calls WMS immediately.
225
+ 2) Fallback: Gemini function-calling (functionDeclarations) if fast-path doesn't trigger.
226
+ Returns your existing bot response dict shape; None when no tool applies.
227
  """
228
+ # --- Guards / session ---
 
 
229
  if not _ensure_wms_session():
230
  return None
231
 
 
232
  session_data = _WMS_SESSION_DATA
233
  default_wh = session_data.get("user_warehouse_id") if session_data else None
234
 
235
+ # ------------------------------------------------------------------
236
+ # 1) FAST-PATH (no Gemini needed): item / location / holds
237
+ # ------------------------------------------------------------------
238
+ import re
239
+ text = (user_text or "").strip()
240
+
241
+ # a) ITEM: "Get item number 100001 ..." / "Show item 100001 ..."
242
+ m_item = re.search(r"\bitem\s+(?:number|id)?\s*[: ]?([A-Za-z0-9\-]+)", text, flags=re.IGNORECASE)
243
+ if m_item:
244
+ item_num = m_item.group(1)
245
+
246
+ # Detect requested field labels (matches labels in your ITEM_MAPPING)
247
+ try:
248
+ item_labels = set(TOOL_REGISTRY["get_item_data"]["mapping"].keys())
249
+ except Exception:
250
+ item_labels = set()
251
+ requested = [lbl for lbl in item_labels if re.search(rf"\b{re.escape(lbl)}\b", text, flags=re.IGNORECASE)]
252
+
253
+ try:
254
+ raw = get_item(session_data, default_wh, item_num) # <-- your real function
255
+ if isinstance(raw, dict) and raw.get("error"):
256
+ return {
257
+ "bot_response": f"⚠️ {raw['error']}",
258
+ "status": "PARTIAL",
259
+ "context_found": False,
260
+ "ask_resolved": False,
261
+ "suggest_incident": True,
262
+ "followup": "Should I raise a ServiceNow ticket?",
263
+ "top_hits": [],
264
+ "sources": [],
265
+ "debug": {"intent": "wms_error_fastpath", "tool": "get_item_data"},
266
+ "source": "ERROR",
267
+ }
268
+
269
+ data_list = raw.get("data", []) or []
270
+ if not data_list:
271
+ return {
272
+ "bot_response": "I searched WMS but couldn't find matching records.",
273
+ "status": "PARTIAL",
274
+ "context_found": False,
275
+ "ask_resolved": False,
276
+ "suggest_incident": True,
277
+ "followup": "Do you want me to raise a ticket or try a different ID?",
278
+ "top_hits": [],
279
+ "sources": [],
280
+ "debug": {"intent": "wms_empty_fastpath", "tool": "get_item_data"},
281
+ "source": "LIVE_API",
282
+ }
283
+
284
+ # Build summary table
285
+ cfg = TOOL_REGISTRY["get_item_data"]
286
+ flat = flatten_json(data_list[0])
287
+ if requested:
288
+ filtered_map = {k: v for k, v in cfg["mapping"].items() if k in requested or k == "Item"}
289
+ order = ["Item"] + [f for f in requested if f != "Item"]
290
+ else:
291
+ summary_keys = ["Item", "Description", "Item type", "Warehouse ID"]
292
+ filtered_map = {k: v for k, v in cfg["mapping"].items() if k in summary_keys}
293
+ order = summary_keys
294
+
295
+ cleaned = extract_fields(flat, filtered_map, cfg.get("formatters", {}), requested_order=order)
296
+ lines = [f"{k}: {cleaned.get(k, 'N/A')}" for k in order]
297
+ msg = "Details:\n" + "\n".join(lines)
298
+
299
+ return {
300
+ "bot_response": msg,
301
+ "status": "OK",
302
+ "context_found": True,
303
+ "ask_resolved": False,
304
+ "suggest_incident": False,
305
+ "followup": None,
306
+ "top_hits": [],
307
+ "sources": [],
308
+ "debug": {"intent": "get_item_data_fastpath", "warehouse_id": default_wh},
309
+ "source": "LIVE_API",
310
+ "type": "table",
311
+ "data": [data_list[0]],
312
+ "show_export": False,
313
+ }
314
+ except Exception as e:
315
+ print("[WMS fastpath:item] error:", e)
316
+ # fall through to Gemini
317
+
318
+ # b) LOCATION: "Check location A1-01-01" / "Status of bin A1-01-01"
319
+ m_loc = re.search(r"\b(?:check|status|loc(?:ation)?)\s+(?:bin\s+)?([A-Za-z0-9\-]+)", text, flags=re.IGNORECASE)
320
+ if m_loc:
321
+ stoloc = m_loc.group(1)
322
+ try:
323
+ raw = get_location_status(session_data, default_wh, stoloc)
324
+ if isinstance(raw, dict) and raw.get("error"):
325
+ return {
326
+ "bot_response": f"⚠️ {raw['error']}",
327
+ "status": "PARTIAL",
328
+ "context_found": False,
329
+ "ask_resolved": False,
330
+ "suggest_incident": True,
331
+ "followup": "Should I raise a ServiceNow ticket?",
332
+ "top_hits": [],
333
+ "sources": [],
334
+ "debug": {"intent": "wms_error_fastpath", "tool": "get_location_status"},
335
+ "source": "ERROR",
336
+ }
337
+
338
+ data_list = raw.get("data", []) or []
339
+ cfg = TOOL_REGISTRY["get_location_status"]
340
+ rows_for_text = []
341
+ for item in data_list:
342
+ flat = flatten_json(item)
343
+ row = extract_fields(flat, cfg["mapping"], cfg.get("formatters", {}))
344
+ rows_for_text.append("• " + "; ".join(f"{k}: {v}" for k, v in row.items()))
345
+ msg = f"Location '{stoloc}' status:\n" + ("\n".join(rows_for_text[:1]) if rows_for_text else "No status returned")
346
+
347
+ return {
348
+ "bot_response": msg,
349
+ "status": "OK",
350
+ "context_found": True,
351
+ "ask_resolved": False,
352
+ "suggest_incident": False,
353
+ "followup": None,
354
+ "top_hits": [],
355
+ "sources": [],
356
+ "debug": {"intent": "get_location_status_fastpath", "warehouse_id": default_wh},
357
+ "source": "LIVE_API",
358
+ "type": "table",
359
+ "data": data_list,
360
+ "show_export": False,
361
+ }
362
+ except Exception as e:
363
+ print("[WMS fastpath:location] error:", e)
364
+ # fall through
365
+
366
+ # c) HOLDS: "Show holds for LPN 123456" / "Show all inventory holds"
367
+ m_lpn = re.search(r"\b(?:lpn|lodnum)\s*[: ]?([A-Za-z0-9\-]+)", text, flags=re.IGNORECASE)
368
+ want_all_holds = re.search(r"\b(all\s+inventory\s+holds|list\s+holds|show\s+holds)\b", text, flags=re.IGNORECASE)
369
+ if m_lpn or want_all_holds:
370
+ kwargs = {}
371
+ if m_lpn:
372
+ kwargs["lodnum"] = m_lpn.group(1)
373
+ try:
374
+ raw = get_inventory_holds(session_data, default_wh, **kwargs)
375
+ if isinstance(raw, dict) and raw.get("error"):
376
+ return {
377
+ "bot_response": f"⚠️ {raw['error']}",
378
+ "status": "PARTIAL",
379
+ "context_found": False,
380
+ "ask_resolved": False,
381
+ "suggest_incident": True,
382
+ "followup": "Should I raise a ServiceNow ticket?",
383
+ "top_hits": [],
384
+ "sources": [],
385
+ "debug": {"intent": "wms_error_fastpath", "tool": "get_inventory_holds"},
386
+ "source": "ERROR",
387
+ }
388
+
389
+ data_list = raw.get("data", []) or []
390
+ if not data_list:
391
+ return {
392
+ "bot_response": "I searched WMS but couldn't find matching records.",
393
+ "status": "PARTIAL",
394
+ "context_found": False,
395
+ "ask_resolved": False,
396
+ "suggest_incident": True,
397
+ "followup": "Do you want me to raise a ticket or try a different ID?",
398
+ "top_hits": [],
399
+ "sources": [],
400
+ "debug": {"intent": "wms_empty_fastpath", "tool": "get_inventory_holds"},
401
+ "source": "LIVE_API",
402
+ }
403
+
404
+ cfg = TOOL_REGISTRY["get_inventory_holds"]
405
+ rows_for_text = []
406
+ def _to_int(x):
407
+ try: return int(x or 0)
408
+ except Exception: return 0
409
+ total_qty = 0
410
+ for item in data_list:
411
+ flat = flatten_json(item)
412
+ row = extract_fields(flat, cfg["mapping"], cfg.get("formatters", {}))
413
+ rows_for_text.append("• " + "; ".join(f"{k}: {v}" for k, v in row.items()))
414
+ total_qty += _to_int(item.get("untqty", 0))
415
+
416
+ msg = f"Found {len(rows_for_text)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows_for_text[:10])
417
+
418
+ return {
419
+ "bot_response": msg,
420
+ "status": "OK",
421
+ "context_found": True,
422
+ "ask_resolved": False,
423
+ "suggest_incident": False,
424
+ "followup": None,
425
+ "top_hits": [],
426
+ "sources": [],
427
+ "debug": {"intent": "get_inventory_holds_fastpath", "warehouse_id": default_wh},
428
+ "source": "LIVE_API",
429
+ "type": "table",
430
+ "data": data_list,
431
+ "show_export": True,
432
+ }
433
+ except Exception as e:
434
+ print("[WMS fastpath:holds] error:", e)
435
+ # fall through
436
+
437
+ # ------------------------------------------------------------------
438
+ # 2) Gemini function-calling fallback (if fast-path didn't match)
439
+ # ------------------------------------------------------------------
440
  payload = {
441
  "contents": [{"role": "user", "parts": [{"text": user_text}]}],
442
+ "tools": [{"functionDeclarations": ALL_TOOLS}],
443
+ "toolConfig": {"functionCallingConfig": {"mode": "ANY"}},
444
  }
445
 
446
  try:
 
466
  if not cfg:
467
  return None
468
 
 
469
  wh = (
470
  args.pop("warehouse_id", None)
471
  or args.pop("warehouseId", None)
 
473
  or default_wh
474
  )
475
 
 
476
  tool_fn = cfg.get("function")
477
  if not callable(tool_fn):
478
  return {
 
488
  "source": "ERROR",
489
  }
490
 
491
+ # ---- Correct invocation for your real signatures ----
 
 
 
492
  if tool_name == "get_inventory_holds":
493
+ call_kwargs = {**args}
494
  raw = tool_fn(session_data, wh, **call_kwargs)
495
 
496
  elif tool_name == "get_location_status":
 
516
  extra_kwargs = {k: v for k, v in args.items() if k != id_param}
517
  raw = tool_fn(session_data, wh, target, **extra_kwargs)
518
 
519
+ # ---- Common handling ----
520
  if isinstance(raw, dict) and raw.get("error"):
521
  return {
522
  "bot_response": f"⚠️ {raw['error']}",
 
531
  "source": "ERROR",
532
  }
533
 
 
534
  response_key = cfg.get("response_key", "data")
535
  data_list = (raw.get(response_key, []) if isinstance(raw, dict) else []) or []
536
  if not data_list:
 
547
  "source": "LIVE_API",
548
  }
549
 
 
550
  if tool_name in ("get_inventory_holds", "get_location_status"):
551
  rows_for_text, total_qty = [], 0
552
  for item in data_list:
 
556
 
557
  if tool_name == "get_inventory_holds":
558
  def _to_int(x):
559
+ try: return int(x or 0)
560
+ except Exception: return 0
 
 
561
  total_qty = sum(_to_int(it.get("untqty", 0)) for it in data_list)
562
  msg = f"Found {len(rows_for_text)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows_for_text[:10])
563
  else:
 
580
  "show_export": True if tool_name == "get_inventory_holds" else False,
581
  }
582
 
583
+ # Item summary + one-row table
584
  flat = flatten_json(data_list[0])
585
  requested = args.get("fields", [])
586
  if requested:
 
628
  print("[WMS] tool call error:", e)
629
  return None
630
 
 
631
  def _detect_error_families(msg: str) -> list:
632
  low = (msg or "").lower()
633
  low_norm = re.sub(r"[^\w\s]", " ", low)