srilakshu012456 commited on
Commit
5f55561
·
verified ·
1 Parent(s): 8da8b0a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +102 -25
main.py CHANGED
@@ -191,11 +191,13 @@ ERROR_FAMILY_SYNS = {
191
 
192
 
193
 
 
194
  def _try_wms_tool(user_text: str) -> dict | None:
195
  """
196
  Uses Gemini function calling + TOOL_REGISTRY to decide if a WMS tool should be called.
197
  Returns YOUR existing response dict shape. If no tool match or any error -> returns None.
198
  """
 
199
  # Guards
200
  if not GEMINI_API_KEY:
201
  return None
@@ -216,21 +218,33 @@ def _try_wms_tool(user_text: str) -> dict | None:
216
  timeout=25,
217
  verify=GEMINI_SSL_VERIFY
218
  )
 
 
219
  data = resp.json()
 
220
  candidates = data.get("candidates", [])
221
- part = candidates[0].get("content", {}).get("parts", [{}])[0] if candidates else {}
222
- if "functionCall" not in part:
 
 
 
 
 
 
 
 
 
223
  return None
224
 
225
- fn = part["functionCall"]
226
- tool_name = fn["name"]
 
227
  cfg = TOOL_REGISTRY.get(tool_name)
228
  if not cfg:
 
229
  return None
230
 
231
- # -------- Prepare arguments and warehouse --------
232
- args = fn.get("args", {}) or {}
233
- # Respect explicit warehouse override if present, else use session default
234
  wh = (
235
  args.pop("warehouse_id", None)
236
  or args.pop("warehouseId", None)
@@ -238,16 +252,60 @@ def _try_wms_tool(user_text: str) -> dict | None:
238
  or _WMS_WAREHOUSE_ID
239
  )
240
 
241
- # -------- Call the underlying function with correct signature --------
242
- # Holds & Location use kwargs (e.g., lodnum/subnum/dtlnum OR stoloc); Item uses single ID
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  if tool_name in ("get_inventory_holds", "get_location_status"):
244
- raw = cfg"function" # <<< invoke function
 
 
 
245
  else:
246
- target = args.get(cfg["id_param"])
247
- raw = cfg"function" # <<< invoke function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
- # -------- Error handling --------
250
- if raw.get("error"):
251
  return {
252
  "bot_response": f"⚠️ {raw['error']}",
253
  "status": "PARTIAL",
@@ -258,10 +316,12 @@ def _try_wms_tool(user_text: str) -> dict | None:
258
  "top_hits": [],
259
  "sources": [],
260
  "debug": {"intent": "wms_error", "tool": tool_name},
261
- "source": "ERROR", # optional for frontend branches
262
  }
263
 
264
- data_list = raw.get(cfg.get("response_key", "data"), [])
 
 
265
  if not data_list:
266
  return {
267
  "bot_response": "I searched WMS but couldn't find matching records.",
@@ -273,7 +333,7 @@ def _try_wms_tool(user_text: str) -> dict | None:
273
  "top_hits": [],
274
  "sources": [],
275
  "debug": {"intent": "wms_empty", "tool": tool_name},
276
- "source": "LIVE_API", # still a live call, just no rows
277
  }
278
 
279
  # -------- TABLE: Holds / Location --------
@@ -288,7 +348,13 @@ def _try_wms_tool(user_text: str) -> dict | None:
288
 
289
  if tool_name == "get_inventory_holds":
290
  # Only holds have quantity in most schemas
291
- total_qty = sum(int(it.get("untqty", 0) or 0) for it in data_list)
 
 
 
 
 
 
292
  msg = f"Found {len(rows_for_text)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows_for_text[:10])
293
  else:
294
  target_loc = args.get("stoloc") or ""
@@ -304,16 +370,15 @@ def _try_wms_tool(user_text: str) -> dict | None:
304
  "top_hits": [],
305
  "sources": [],
306
  "debug": {"intent": tool_name, "warehouse_id": wh},
307
-
308
- # Front-end signals → render table exactly like your friend's code
309
  "source": "LIVE_API",
310
  "type": "table",
311
- "data": data_list, # rows for <TableMessage />
312
  "show_export": True if tool_name == "get_inventory_holds" else False,
313
  }
314
 
315
- # -------- ITEM: summary + single-row table (consistent with "data => table") --------
316
  flat = flatten_json(data_list[0])
 
317
  requested = args.get("fields", [])
318
  if requested:
319
  filtered_map = {k: v for k, v in cfg["mapping"].items() if k in requested or k == "Item"}
@@ -337,18 +402,30 @@ def _try_wms_tool(user_text: str) -> dict | None:
337
  "top_hits": [],
338
  "sources": [],
339
  "debug": {"intent": "get_item_data", "warehouse_id": wh},
340
-
341
  "source": "LIVE_API",
342
  "type": "table",
343
- "data": [data_list[0]], # single-row table for item
344
  "show_export": False,
345
  }
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  except Exception as e:
348
  print("[WMS] tool call error:", e)
349
  return None
350
 
351
-
352
  def _detect_error_families(msg: str) -> list:
353
  low = (msg or "").lower()
354
  low_norm = re.sub(r"[^\w\s]", " ", low)
 
191
 
192
 
193
 
194
+
195
  def _try_wms_tool(user_text: str) -> dict | None:
196
  """
197
  Uses Gemini function calling + TOOL_REGISTRY to decide if a WMS tool should be called.
198
  Returns YOUR existing response dict shape. If no tool match or any error -> returns None.
199
  """
200
+
201
  # Guards
202
  if not GEMINI_API_KEY:
203
  return None
 
218
  timeout=25,
219
  verify=GEMINI_SSL_VERIFY
220
  )
221
+ # Raise if HTTP not 2xx so we handle cleanly
222
+ resp.raise_for_status()
223
  data = resp.json()
224
+
225
  candidates = data.get("candidates", [])
226
+ if not candidates:
227
+ return None
228
+
229
+ # Defensive extraction of the first part
230
+ content = candidates[0].get("content", {})
231
+ parts = content.get("parts", [])
232
+ part = parts[0] if parts else {}
233
+
234
+ # Gemini returns a function call in part["functionCall"] when a tool is selected
235
+ fn_call = part.get("functionCall")
236
+ if not fn_call or "name" not in fn_call:
237
  return None
238
 
239
+ tool_name = fn_call["name"]
240
+ args = fn_call.get("args", {}) or {}
241
+
242
  cfg = TOOL_REGISTRY.get(tool_name)
243
  if not cfg:
244
+ # Tool not registered → bail
245
  return None
246
 
247
+ # -------- Resolve warehouse --------
 
 
248
  wh = (
249
  args.pop("warehouse_id", None)
250
  or args.pop("warehouseId", None)
 
252
  or _WMS_WAREHOUSE_ID
253
  )
254
 
255
+ # -------- Invoke the tool callable correctly --------
256
+ tool_fn = cfg.get("function")
257
+ if not callable(tool_fn):
258
+ # Misconfigured registry
259
+ return {
260
+ "bot_response": f"⚠️ Tool '{tool_name}' is not callable. Check TOOL_REGISTRY['{tool_name}']['function'].",
261
+ "status": "PARTIAL",
262
+ "context_found": False,
263
+ "ask_resolved": False,
264
+ "suggest_incident": True,
265
+ "followup": "Should I raise a ServiceNow ticket?",
266
+ "top_hits": [],
267
+ "sources": [],
268
+ "debug": {"intent": "wms_config_error", "tool": tool_name},
269
+ "source": "ERROR",
270
+ }
271
+
272
+ # Holds & Location use kwargs (e.g., lodnum/subnum/dtlnum OR stoloc);
273
+ # Item uses single ID parameter (e.g., item_id) plus warehouse.
274
  if tool_name in ("get_inventory_holds", "get_location_status"):
275
+ # Pass warehouse as kwarg if present; rest comes from args
276
+ call_kwargs = {"warehouse_id": wh} if wh is not None else {}
277
+ call_kwargs.update(args)
278
+ raw = tool_fn(**call_kwargs)
279
  else:
280
+ id_param = cfg.get("id_param")
281
+ target = args.get(id_param)
282
+ if id_param and target is None:
283
+ return {
284
+ "bot_response": f"Couldn’t find required parameter '{id_param}' in function args.",
285
+ "status": "PARTIAL",
286
+ "context_found": False,
287
+ "ask_resolved": False,
288
+ "suggest_incident": False,
289
+ "followup": f"Can you provide the {id_param}?",
290
+ "top_hits": [],
291
+ "sources": [],
292
+ "debug": {"intent": "wms_missing_id_param", "tool": tool_name},
293
+ "source": "CLIENT",
294
+ }
295
+
296
+ call_kwargs = {"warehouse_id": wh} if wh is not None else {}
297
+ # Include the ID param, plus any other Gemini-parsed args
298
+ if id_param:
299
+ call_kwargs[id_param] = target
300
+ # Keep other non-ID args (filters, etc.)
301
+ for k, v in args.items():
302
+ if k != id_param:
303
+ call_kwargs[k] = v
304
+
305
+ raw = tool_fn(**call_kwargs)
306
 
307
+ # -------- Error handling from tool layer --------
308
+ if isinstance(raw, dict) and raw.get("error"):
309
  return {
310
  "bot_response": f"⚠️ {raw['error']}",
311
  "status": "PARTIAL",
 
316
  "top_hits": [],
317
  "sources": [],
318
  "debug": {"intent": "wms_error", "tool": tool_name},
319
+ "source": "ERROR",
320
  }
321
 
322
+ # -------- Extract data list (default key = 'data') --------
323
+ response_key = cfg.get("response_key", "data")
324
+ data_list = (raw.get(response_key, []) if isinstance(raw, dict) else []) or []
325
  if not data_list:
326
  return {
327
  "bot_response": "I searched WMS but couldn't find matching records.",
 
333
  "top_hits": [],
334
  "sources": [],
335
  "debug": {"intent": "wms_empty", "tool": tool_name},
336
+ "source": "LIVE_API",
337
  }
338
 
339
  # -------- TABLE: Holds / Location --------
 
348
 
349
  if tool_name == "get_inventory_holds":
350
  # Only holds have quantity in most schemas
351
+ def _to_int(x):
352
+ try:
353
+ return int(x or 0)
354
+ except Exception:
355
+ return 0
356
+
357
+ total_qty = sum(_to_int(it.get("untqty", 0)) for it in data_list)
358
  msg = f"Found {len(rows_for_text)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows_for_text[:10])
359
  else:
360
  target_loc = args.get("stoloc") or ""
 
370
  "top_hits": [],
371
  "sources": [],
372
  "debug": {"intent": tool_name, "warehouse_id": wh},
 
 
373
  "source": "LIVE_API",
374
  "type": "table",
375
+ "data": data_list,
376
  "show_export": True if tool_name == "get_inventory_holds" else False,
377
  }
378
 
379
+ # -------- ITEM: summary + single-row table --------
380
  flat = flatten_json(data_list[0])
381
+
382
  requested = args.get("fields", [])
383
  if requested:
384
  filtered_map = {k: v for k, v in cfg["mapping"].items() if k in requested or k == "Item"}
 
402
  "top_hits": [],
403
  "sources": [],
404
  "debug": {"intent": "get_item_data", "warehouse_id": wh},
 
405
  "source": "LIVE_API",
406
  "type": "table",
407
+ "data": [data_list[0]],
408
  "show_export": False,
409
  }
410
 
411
+ except requests.HTTPError as e:
412
+ # Surface Gemini API issues more clearly
413
+ return {
414
+ "bot_response": f"Gemini API error: {e}",
415
+ "status": "PARTIAL",
416
+ "context_found": False,
417
+ "ask_resolved": False,
418
+ "suggest_incident": True,
419
+ "followup": "Should I raise a ServiceNow ticket?",
420
+ "top_hits": [],
421
+ "sources": [],
422
+ "debug": {"intent": "gemini_http_error"},
423
+ "source": "ERROR",
424
+ }
425
  except Exception as e:
426
  print("[WMS] tool call error:", e)
427
  return None
428
 
 
429
  def _detect_error_families(msg: str) -> list:
430
  low = (msg or "").lower()
431
  low_norm = re.sub(r"[^\w\s]", " ", low)