srilakshu012456 commited on
Commit
4b4e08b
·
verified ·
1 Parent(s): 5ada0c7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +164 -0
main.py CHANGED
@@ -15,6 +15,15 @@ from pydantic import BaseModel
15
  from dotenv import load_dotenv
16
  from difflib import SequenceMatcher
17
 
 
 
 
 
 
 
 
 
 
18
  # KB services
19
  from services.kb_creation import (
20
  collection,
@@ -43,6 +52,25 @@ GEMINI_URL = (
43
  )
44
  os.environ["POSTHOG_DISABLED"] = "true"
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ---------------------------------------------------------------------
47
  # Minimal server-side cache (used to populate short description if frontend didn’t pass last_issue)
48
  # ---------------------------------------------------------------------
@@ -161,6 +189,136 @@ ERROR_FAMILY_SYNS = {
161
  ),
162
  }
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  def _detect_error_families(msg: str) -> list:
166
  low = (msg or "").lower()
@@ -782,6 +940,12 @@ async def chat_with_ai(input_data: ChatInput):
782
  }
783
  except Exception as e:
784
  raise HTTPException(status_code=500, detail=safe_str(e))
 
 
 
 
 
 
785
 
786
  # ------------------ Hybrid KB search ------------------
787
  kb_results = hybrid_search_knowledge_base(input_data.user_message, top_k=10, alpha=0.6, beta=0.4)
 
15
  from dotenv import load_dotenv
16
  from difflib import SequenceMatcher
17
 
18
+ # --- WMS API integration imports ---
19
+ from public_api.public_api_auth import login
20
+ from public_api.public_api_item import get_item
21
+ from public_api.public_api_inventory import get_inventory_holds
22
+ from public_api.public_api_location import get_location_status
23
+ from public_api.utils import flatten_json, extract_fields
24
+ from public_api.field_mapping import TOOL_REGISTRY, ALL_TOOLS
25
+
26
+
27
  # KB services
28
  from services.kb_creation import (
29
  collection,
 
52
  )
53
  os.environ["POSTHOG_DISABLED"] = "true"
54
 
55
+ # --- API WMS session cache ---
56
+ _WMS_SESSION_DATA = None
57
+ _WMS_WAREHOUSE_ID = None
58
+
59
+ def _ensure_wms_session() -> bool:
60
+ """Login once to WMS and cache session_data + default warehouse."""
61
+ global _WMS_SESSION_DATA, _WMS_WAREHOUSE_ID
62
+ if _WMS_SESSION_DATA and _WMS_WAREHOUSE_ID:
63
+ return True
64
+ try:
65
+ session, warehouse_id, ok = login()
66
+ if ok:
67
+ _WMS_SESSION_DATA = {"wms_auth": session, "user_warehouse_id": warehouse_id}
68
+ _WMS_WAREHOUSE_ID = warehouse_id
69
+ return True
70
+ except Exception as e:
71
+ print("[WMS] login failed:", e)
72
+ return False
73
+
74
  # ---------------------------------------------------------------------
75
  # Minimal server-side cache (used to populate short description if frontend didn’t pass last_issue)
76
  # ---------------------------------------------------------------------
 
189
  ),
190
  }
191
 
192
+ def _try_wms_tool(user_text: str) -> dict | None:
193
+ """
194
+ Uses Gemini function calling + TOOL_REGISTRY to decide if a WMS tool should be called.
195
+ Returns YOUR existing response dict shape. If no tool match or any error -> returns None.
196
+ """
197
+ # Safe guards: only attempt if Gemini key exists and WMS login works
198
+ if not GEMINI_API_KEY:
199
+ return None
200
+ if not _ensure_wms_session():
201
+ return None
202
+
203
+ headers = {"Content-Type": "application/json"}
204
+ payload = {
205
+ "contents": [{"parts": [{"text": user_text}]}],
206
+ "tools": [{"function_declarations": ALL_TOOLS}],
207
+ }
208
+
209
+ try:
210
+ resp = requests.post(GEMINI_URL, headers=headers, json=payload, timeout=25, verify=GEMINI_SSL_VERIFY)
211
+ data = resp.json()
212
+ candidates = data.get("candidates", [])
213
+ part = candidates[0].get("content", {}).get("parts", [{}])[0] if candidates else {}
214
+ if "functionCall" not in part:
215
+ return None
216
+
217
+ fn = part["functionCall"]
218
+ tool_name = fn["name"]
219
+ cfg = TOOL_REGISTRY.get(tool_name)
220
+ if not cfg:
221
+ return None
222
+
223
+ # Prepare arguments and warehouse
224
+ args = fn.get("args", {}) or {}
225
+ for k in ("wh_id", "warehouseId", "warehouse_id"): # WMS variations
226
+ args.pop(k, None)
227
+ wh = _WMS_WAREHOUSE_ID
228
+
229
+ # Call underlying function per signature
230
+ if tool_name in ("get_inventory_holds", "get_location_status"):
231
+ raw = cfg"function"
232
+ else:
233
+ target = args.get(cfg["id_param"])
234
+ raw = cfg"function"
235
+
236
+ # Error path -> keep your response shape
237
+ if raw.get("error"):
238
+ return {
239
+ "bot_response": f"⚠️ {raw['error']}",
240
+ "status": "PARTIAL",
241
+ "context_found": False,
242
+ "ask_resolved": False,
243
+ "suggest_incident": True,
244
+ "followup": "Should I raise a ServiceNow ticket?",
245
+ "top_hits": [],
246
+ "sources": [],
247
+ "debug": {"intent": "wms_error", "tool": tool_name},
248
+ }
249
+
250
+ data_list = raw.get(cfg.get("response_key", "data"), [])
251
+ if not data_list:
252
+ return {
253
+ "bot_response": "I searched WMS but couldn't find matching records.",
254
+ "status": "PARTIAL",
255
+ "context_found": False,
256
+ "ask_resolved": False,
257
+ "suggest_incident": True,
258
+ "followup": "Do you want me to raise a ticket or try a different ID?",
259
+ "top_hits": [],
260
+ "sources": [],
261
+ "debug": {"intent": "wms_empty", "tool": tool_name},
262
+ }
263
+
264
+ # Holds / Location → concise lines
265
+ if tool_name in ("get_inventory_holds", "get_location_status"):
266
+ rows, total_qty = [], 0
267
+ for item in data_list:
268
+ flat = flatten_json(item)
269
+ row = extract_fields(flat, cfg["mapping"], cfg.get("formatters", {}))
270
+ rows.append("• " + "; ".join(f"{k}: {v}" for k, v in row.items()))
271
+ total_qty += int(item.get("untqty", 0))
272
+
273
+ if tool_name == "get_inventory_holds":
274
+ msg = f"Found {len(rows)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows[:10])
275
+ else:
276
+ target_loc = args.get("stoloc") or ""
277
+ msg = f"Location '{target_loc}' status:\n" + ("\n".join(rows[:1]) if rows else "No status returned")
278
+
279
+ return {
280
+ "bot_response": msg,
281
+ "status": "OK",
282
+ "context_found": True,
283
+ "ask_resolved": False,
284
+ "suggest_incident": False,
285
+ "followup": None,
286
+ "top_hits": [],
287
+ "sources": [],
288
+ "debug": {"intent": tool_name, "warehouse_id": wh},
289
+ }
290
+
291
+ # Item → short summary using mapping & requested fields
292
+ flat = flatten_json(data_list[0])
293
+ requested = args.get("fields", [])
294
+ if requested:
295
+ filtered_map = {k: v for k, v in cfg["mapping"].items() if k in requested or k == "Item"}
296
+ order = ["Item"] + [f for f in requested if f != "Item"]
297
+ else:
298
+ summary_keys = ["Item", "Description", "Item type", "Warehouse ID"]
299
+ filtered_map = {k: v for k, v in cfg["mapping"].items() if k in summary_keys}
300
+ order = summary_keys
301
+
302
+ cleaned = extract_fields(flat, filtered_map, cfg.get("formatters", {}), requested_order=order)
303
+ lines = [f"{k}: {cleaned.get(k, 'N/A')}" for k in order]
304
+ msg = "Details:\n" + "\n".join(lines)
305
+
306
+ return {
307
+ "bot_response": msg,
308
+ "status": "OK",
309
+ "context_found": True,
310
+ "ask_resolved": False,
311
+ "suggest_incident": False,
312
+ "followup": None,
313
+ "top_hits": [],
314
+ "sources": [],
315
+ "debug": {"intent": "get_item_data", "warehouse_id": wh},
316
+ }
317
+
318
+ except Exception as e:
319
+ print("[WMS] tool call error:", e)
320
+ return None
321
+
322
 
323
  def _detect_error_families(msg: str) -> list:
324
  low = (msg or "").lower()
 
940
  }
941
  except Exception as e:
942
  raise HTTPException(status_code=500, detail=safe_str(e))
943
+
944
+ # --- Try WMS API tools before KB search ---
945
+ res = _try_wms_tool(input_data.user_message)
946
+ if res is not None:
947
+ return res
948
+
949
 
950
  # ------------------ Hybrid KB search ------------------
951
  kb_results = hybrid_search_knowledge_base(input_data.user_message, top_k=10, alpha=0.6, beta=0.4)