Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -15,6 +15,15 @@ from pydantic import BaseModel
|
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from difflib import SequenceMatcher
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# KB services
|
| 19 |
from services.kb_creation import (
|
| 20 |
collection,
|
|
@@ -43,6 +52,25 @@ GEMINI_URL = (
|
|
| 43 |
)
|
| 44 |
os.environ["POSTHOG_DISABLED"] = "true"
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# ---------------------------------------------------------------------
|
| 47 |
# Minimal server-side cache (used to populate short description if frontend didn’t pass last_issue)
|
| 48 |
# ---------------------------------------------------------------------
|
|
@@ -161,6 +189,136 @@ ERROR_FAMILY_SYNS = {
|
|
| 161 |
),
|
| 162 |
}
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
def _detect_error_families(msg: str) -> list:
|
| 166 |
low = (msg or "").lower()
|
|
@@ -782,6 +940,12 @@ async def chat_with_ai(input_data: ChatInput):
|
|
| 782 |
}
|
| 783 |
except Exception as e:
|
| 784 |
raise HTTPException(status_code=500, detail=safe_str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
|
| 786 |
# ------------------ Hybrid KB search ------------------
|
| 787 |
kb_results = hybrid_search_knowledge_base(input_data.user_message, top_k=10, alpha=0.6, beta=0.4)
|
|
|
|
| 15 |
from dotenv import load_dotenv
|
| 16 |
from difflib import SequenceMatcher
|
| 17 |
|
| 18 |
+
# --- WMS API integration imports ---
|
| 19 |
+
from public_api.public_api_auth import login
|
| 20 |
+
from public_api.public_api_item import get_item
|
| 21 |
+
from public_api.public_api_inventory import get_inventory_holds
|
| 22 |
+
from public_api.public_api_location import get_location_status
|
| 23 |
+
from public_api.utils import flatten_json, extract_fields
|
| 24 |
+
from public_api.field_mapping import TOOL_REGISTRY, ALL_TOOLS
|
| 25 |
+
|
| 26 |
+
|
| 27 |
# KB services
|
| 28 |
from services.kb_creation import (
|
| 29 |
collection,
|
|
|
|
| 52 |
)
|
| 53 |
os.environ["POSTHOG_DISABLED"] = "true"
|
| 54 |
|
| 55 |
+
# --- API WMS session cache ---
|
| 56 |
+
_WMS_SESSION_DATA = None
|
| 57 |
+
_WMS_WAREHOUSE_ID = None
|
| 58 |
+
|
| 59 |
+
def _ensure_wms_session() -> bool:
|
| 60 |
+
"""Login once to WMS and cache session_data + default warehouse."""
|
| 61 |
+
global _WMS_SESSION_DATA, _WMS_WAREHOUSE_ID
|
| 62 |
+
if _WMS_SESSION_DATA and _WMS_WAREHOUSE_ID:
|
| 63 |
+
return True
|
| 64 |
+
try:
|
| 65 |
+
session, warehouse_id, ok = login()
|
| 66 |
+
if ok:
|
| 67 |
+
_WMS_SESSION_DATA = {"wms_auth": session, "user_warehouse_id": warehouse_id}
|
| 68 |
+
_WMS_WAREHOUSE_ID = warehouse_id
|
| 69 |
+
return True
|
| 70 |
+
except Exception as e:
|
| 71 |
+
print("[WMS] login failed:", e)
|
| 72 |
+
return False
|
| 73 |
+
|
| 74 |
# ---------------------------------------------------------------------
|
| 75 |
# Minimal server-side cache (used to populate short description if frontend didn’t pass last_issue)
|
| 76 |
# ---------------------------------------------------------------------
|
|
|
|
| 189 |
),
|
| 190 |
}
|
| 191 |
|
| 192 |
+
def _try_wms_tool(user_text: str) -> dict | None:
|
| 193 |
+
"""
|
| 194 |
+
Uses Gemini function calling + TOOL_REGISTRY to decide if a WMS tool should be called.
|
| 195 |
+
Returns YOUR existing response dict shape. If no tool match or any error -> returns None.
|
| 196 |
+
"""
|
| 197 |
+
# Safe guards: only attempt if Gemini key exists and WMS login works
|
| 198 |
+
if not GEMINI_API_KEY:
|
| 199 |
+
return None
|
| 200 |
+
if not _ensure_wms_session():
|
| 201 |
+
return None
|
| 202 |
+
|
| 203 |
+
headers = {"Content-Type": "application/json"}
|
| 204 |
+
payload = {
|
| 205 |
+
"contents": [{"parts": [{"text": user_text}]}],
|
| 206 |
+
"tools": [{"function_declarations": ALL_TOOLS}],
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
resp = requests.post(GEMINI_URL, headers=headers, json=payload, timeout=25, verify=GEMINI_SSL_VERIFY)
|
| 211 |
+
data = resp.json()
|
| 212 |
+
candidates = data.get("candidates", [])
|
| 213 |
+
part = candidates[0].get("content", {}).get("parts", [{}])[0] if candidates else {}
|
| 214 |
+
if "functionCall" not in part:
|
| 215 |
+
return None
|
| 216 |
+
|
| 217 |
+
fn = part["functionCall"]
|
| 218 |
+
tool_name = fn["name"]
|
| 219 |
+
cfg = TOOL_REGISTRY.get(tool_name)
|
| 220 |
+
if not cfg:
|
| 221 |
+
return None
|
| 222 |
+
|
| 223 |
+
# Prepare arguments and warehouse
|
| 224 |
+
args = fn.get("args", {}) or {}
|
| 225 |
+
for k in ("wh_id", "warehouseId", "warehouse_id"): # WMS variations
|
| 226 |
+
args.pop(k, None)
|
| 227 |
+
wh = _WMS_WAREHOUSE_ID
|
| 228 |
+
|
| 229 |
+
# Call underlying function per signature
|
| 230 |
+
if tool_name in ("get_inventory_holds", "get_location_status"):
|
| 231 |
+
raw = cfg"function"
|
| 232 |
+
else:
|
| 233 |
+
target = args.get(cfg["id_param"])
|
| 234 |
+
raw = cfg"function"
|
| 235 |
+
|
| 236 |
+
# Error path -> keep your response shape
|
| 237 |
+
if raw.get("error"):
|
| 238 |
+
return {
|
| 239 |
+
"bot_response": f"⚠️ {raw['error']}",
|
| 240 |
+
"status": "PARTIAL",
|
| 241 |
+
"context_found": False,
|
| 242 |
+
"ask_resolved": False,
|
| 243 |
+
"suggest_incident": True,
|
| 244 |
+
"followup": "Should I raise a ServiceNow ticket?",
|
| 245 |
+
"top_hits": [],
|
| 246 |
+
"sources": [],
|
| 247 |
+
"debug": {"intent": "wms_error", "tool": tool_name},
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
data_list = raw.get(cfg.get("response_key", "data"), [])
|
| 251 |
+
if not data_list:
|
| 252 |
+
return {
|
| 253 |
+
"bot_response": "I searched WMS but couldn't find matching records.",
|
| 254 |
+
"status": "PARTIAL",
|
| 255 |
+
"context_found": False,
|
| 256 |
+
"ask_resolved": False,
|
| 257 |
+
"suggest_incident": True,
|
| 258 |
+
"followup": "Do you want me to raise a ticket or try a different ID?",
|
| 259 |
+
"top_hits": [],
|
| 260 |
+
"sources": [],
|
| 261 |
+
"debug": {"intent": "wms_empty", "tool": tool_name},
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
# Holds / Location → concise lines
|
| 265 |
+
if tool_name in ("get_inventory_holds", "get_location_status"):
|
| 266 |
+
rows, total_qty = [], 0
|
| 267 |
+
for item in data_list:
|
| 268 |
+
flat = flatten_json(item)
|
| 269 |
+
row = extract_fields(flat, cfg["mapping"], cfg.get("formatters", {}))
|
| 270 |
+
rows.append("• " + "; ".join(f"{k}: {v}" for k, v in row.items()))
|
| 271 |
+
total_qty += int(item.get("untqty", 0))
|
| 272 |
+
|
| 273 |
+
if tool_name == "get_inventory_holds":
|
| 274 |
+
msg = f"Found {len(rows)} hold records (Total Qty: {total_qty}).\n" + "\n".join(rows[:10])
|
| 275 |
+
else:
|
| 276 |
+
target_loc = args.get("stoloc") or ""
|
| 277 |
+
msg = f"Location '{target_loc}' status:\n" + ("\n".join(rows[:1]) if rows else "No status returned")
|
| 278 |
+
|
| 279 |
+
return {
|
| 280 |
+
"bot_response": msg,
|
| 281 |
+
"status": "OK",
|
| 282 |
+
"context_found": True,
|
| 283 |
+
"ask_resolved": False,
|
| 284 |
+
"suggest_incident": False,
|
| 285 |
+
"followup": None,
|
| 286 |
+
"top_hits": [],
|
| 287 |
+
"sources": [],
|
| 288 |
+
"debug": {"intent": tool_name, "warehouse_id": wh},
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
# Item → short summary using mapping & requested fields
|
| 292 |
+
flat = flatten_json(data_list[0])
|
| 293 |
+
requested = args.get("fields", [])
|
| 294 |
+
if requested:
|
| 295 |
+
filtered_map = {k: v for k, v in cfg["mapping"].items() if k in requested or k == "Item"}
|
| 296 |
+
order = ["Item"] + [f for f in requested if f != "Item"]
|
| 297 |
+
else:
|
| 298 |
+
summary_keys = ["Item", "Description", "Item type", "Warehouse ID"]
|
| 299 |
+
filtered_map = {k: v for k, v in cfg["mapping"].items() if k in summary_keys}
|
| 300 |
+
order = summary_keys
|
| 301 |
+
|
| 302 |
+
cleaned = extract_fields(flat, filtered_map, cfg.get("formatters", {}), requested_order=order)
|
| 303 |
+
lines = [f"{k}: {cleaned.get(k, 'N/A')}" for k in order]
|
| 304 |
+
msg = "Details:\n" + "\n".join(lines)
|
| 305 |
+
|
| 306 |
+
return {
|
| 307 |
+
"bot_response": msg,
|
| 308 |
+
"status": "OK",
|
| 309 |
+
"context_found": True,
|
| 310 |
+
"ask_resolved": False,
|
| 311 |
+
"suggest_incident": False,
|
| 312 |
+
"followup": None,
|
| 313 |
+
"top_hits": [],
|
| 314 |
+
"sources": [],
|
| 315 |
+
"debug": {"intent": "get_item_data", "warehouse_id": wh},
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
except Exception as e:
|
| 319 |
+
print("[WMS] tool call error:", e)
|
| 320 |
+
return None
|
| 321 |
+
|
| 322 |
|
| 323 |
def _detect_error_families(msg: str) -> list:
|
| 324 |
low = (msg or "").lower()
|
|
|
|
| 940 |
}
|
| 941 |
except Exception as e:
|
| 942 |
raise HTTPException(status_code=500, detail=safe_str(e))
|
| 943 |
+
|
| 944 |
+
# --- Try WMS API tools before KB search ---
|
| 945 |
+
res = _try_wms_tool(input_data.user_message)
|
| 946 |
+
if res is not None:
|
| 947 |
+
return res
|
| 948 |
+
|
| 949 |
|
| 950 |
# ------------------ Hybrid KB search ------------------
|
| 951 |
kb_results = hybrid_search_knowledge_base(input_data.user_message, top_k=10, alpha=0.6, beta=0.4)
|