| """SQLite catalog access — FTS search + commodity lookup.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import re |
| import sqlite3 |
| from pathlib import Path |
| from typing import Any |
|
|
| |
| _STOPWORDS: frozenset[str] = frozenset( |
| { |
| "the", |
| "a", |
| "an", |
| "for", |
| "and", |
| "or", |
| "to", |
| "of", |
| "in", |
| "on", |
| "at", |
| "by", |
| "with", |
| "from", |
| "as", |
| "is", |
| "are", |
| "was", |
| "were", |
| "be", |
| "been", |
| "being", |
| "have", |
| "has", |
| "had", |
| "do", |
| "does", |
| "did", |
| "will", |
| "would", |
| "could", |
| "should", |
| "may", |
| "might", |
| "must", |
| "can", |
| "i", |
| "you", |
| "we", |
| "they", |
| "it", |
| "its", |
| "this", |
| "that", |
| "these", |
| "those", |
| "please", |
| "need", |
| "want", |
| "looking", |
| "find", |
| "help", |
| "me", |
| "my", |
| "our", |
| "your", |
| "some", |
| "any", |
| "all", |
| "not", |
| "no", |
| "yes", |
| "get", |
| "give", |
| "show", |
| "tell", |
| "about", |
| "into", |
| "over", |
| "also", |
| "just", |
| "only", |
| "even", |
| "such", |
| "than", |
| "then", |
| "there", |
| "when", |
| "where", |
| "which", |
| "who", |
| "how", |
| "why", |
| "what", |
| "if", |
| "so", |
| "but", |
| "because", |
| "couldnt", |
| "couldn't", |
| "dont", |
| "don't", |
| "doesnt", |
| "doesn't", |
| } |
| ) |
|
|
|
|
def db_path() -> Path:
    """Return the location of the catalogue SQLite database.

    The ``UNSPSC_DB_PATH`` environment variable takes precedence when it is
    set to a non-empty value. Otherwise the path is ``data/unspsc.db`` under
    the directory one level above the directory containing this module.

    Returns:
        The database path. The file is not checked for existence here.
    """
    override = os.environ.get("UNSPSC_DB_PATH")
    if override:
        return Path(override)
    # An unset *or empty* variable falls back to the default: an empty string
    # would otherwise turn into Path("."), i.e. the working directory, which
    # exists but is never a database file.
    return Path(__file__).resolve().parents[1] / "data" / "unspsc.db"
|
|
|
|
def connect() -> sqlite3.Connection:
    """Open the catalogue database and return a connection that yields ``sqlite3.Row`` rows.

    Raises:
        FileNotFoundError: if nothing exists at the path reported by ``db_path()``.
    """
    location = db_path()
    if location.exists():
        # check_same_thread=False: the connection may be used from threads
        # other than the one that created it.
        connection = sqlite3.connect(location, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        return connection
    raise FileNotFoundError(f"Catalogue database missing at {location}")
|
|
|
|
def _keyword_tokens(q: str, *, max_tokens: int = 24) -> list[str]:
    """Extract lowercase search keywords from free text, dropping stopwords.

    Tokens are runs of letters/digits (underscore and punctuation separate
    them). Each kept token is at least two characters long, is not in
    ``_STOPWORDS``, is truncated to 48 characters, and appears only once, in
    first-seen order.

    Words joined by an apostrophe (straight or typographic) are first checked
    against ``_STOPWORDS`` as a whole, so contractions listed there such as
    "don't" are removed entirely instead of leaking fragments like "don" into
    the query. Any other apostrophe-joined word is split at the apostrophe
    and each part is filtered individually.

    Args:
        q: Raw user query.
        max_tokens: Collection stops once this many keywords were gathered.

    Returns:
        The ordered list of keywords (possibly empty).
    """
    seen: set[str] = set()
    out: list[str] = []
    # Match alphanumeric runs, keeping single apostrophes *between* runs so
    # a contraction arrives here as one chunk.
    for chunk in re.findall(r"[^\W_]+(?:['\u2019][^\W_]+)*", q, flags=re.UNICODE):
        word = chunk.lower().replace("\u2019", "'")
        if word in _STOPWORDS:
            continue
        # Apostrophes are not searchable token characters: fall back to the
        # individual alphanumeric parts.
        for tl in word.split("'"):
            if len(tl) < 2 or tl in _STOPWORDS:
                continue
            tl = tl[:48]
            if tl in seen:
                continue
            seen.add(tl)
            out.append(tl)
            if len(out) >= max_tokens:
                return out
    return out
|
|
|
|
| def _fts_and_query(keywords: list[str], *, max_terms: int) -> str: |
| """Strict AND — keep term count low so rows can match.""" |
| if not keywords: |
| return "" |
| terms = keywords[: max(1, max_terms)] |
| return " AND ".join(terms) |
|
|
|
|
| def _fts_or_query(keywords: list[str], *, max_terms: int) -> str: |
| """Loose OR — any keyword can match (ranked when bm25 available).""" |
| if not keywords: |
| return "" |
| terms = keywords[: max(1, max_terms)] |
| return " OR ".join(terms) |
|
|
|
|
| def _fts_select_best_effort( |
| conn: sqlite3.Connection, |
| fts_q: str, |
| lim: int, |
| ) -> list[dict[str, Any]]: |
| cur = conn.cursor() |
| if not fts_q.strip(): |
| return [] |
| try: |
| cur.execute( |
| f""" |
| SELECT c.* |
| FROM commodities_fts |
| JOIN commodities c ON c.id = commodities_fts.rowid |
| WHERE commodities_fts MATCH ? |
| ORDER BY bm25(commodities_fts) |
| LIMIT {lim} |
| """, |
| (fts_q,), |
| ) |
| except sqlite3.OperationalError: |
| cur.execute( |
| f""" |
| SELECT c.* |
| FROM commodities_fts |
| JOIN commodities c ON c.id = commodities_fts.rowid |
| WHERE commodities_fts MATCH ? |
| LIMIT {lim} |
| """, |
| (fts_q,), |
| ) |
| return [row_to_dict(r) for r in cur.fetchall()] |
|
|
|
|
def row_to_dict(r: sqlite3.Row) -> dict[str, Any]:
    """Convert a row to a plain dict and add a nested ``codes`` entry.

    ``codes`` collects the row's ``segment``, ``family``, ``class`` and
    ``commodity`` values (``None`` for any column the row does not have).
    The original columns are kept at the top level as well.
    """
    record = dict(r)
    record["codes"] = {
        level: record.get(level)
        for level in ("segment", "family", "class", "commodity")
    }
    return record
|
|
|
|
| def _like_fallback( |
| conn: sqlite3.Connection, |
| keywords: list[str], |
| raw_fallback_tokens: list[str], |
| lim: int, |
| ) -> list[dict[str, Any]]: |
| """SQL LIKE across titles/path/definition — OR semantics per token.""" |
| tokens = keywords[:14] if keywords else raw_fallback_tokens[:14] |
| if not tokens: |
| return [] |
|
|
| cur = conn.cursor() |
| like_params: list[str] = [] |
| conds: list[str] = [] |
| for t in tokens: |
| pat = f"%{t}%" |
| like_params.extend([pat, pat, pat]) |
| conds.append( |
| "(c.path_titles LIKE ? OR c.commodity_title LIKE ? OR c.commodity_definition LIKE ?)" |
| ) |
| sql = f""" |
| SELECT DISTINCT c.* |
| FROM commodities c |
| WHERE ({' OR '.join(conds)}) |
| LIMIT {lim} |
| """ |
| cur.execute(sql, like_params) |
| return [row_to_dict(r) for r in cur.fetchall()] |
|
|
|
|
def search_catalog(conn: sqlite3.Connection, query: str, limit: int = 25) -> list[dict[str, Any]]:
    """Keyword search over the catalogue, relaxing the match until something is found.

    Stages, stopping at the first one that returns rows:

    1. FTS with the first 5, then 4, then 3 keywords joined by AND.
    2. FTS with up to 12 keywords joined by OR.
    3. SQL LIKE over the commodity text columns.

    Args:
        conn: Open connection to the catalogue database.
        query: Free-text user query; blank or ``None`` yields ``[]``.
        limit: Requested number of rows, clamped to the range 1..80.

    Returns:
        Matching rows as produced by ``row_to_dict`` (possibly empty).
    """
    lim = max(1, min(limit, 80))
    q = (query or "").strip()
    if not q:
        return []

    keywords = _keyword_tokens(q)

    # Strictest first. With fewer than five keywords several of these are the
    # same string (and a single keyword makes the OR form identical too), so
    # each distinct MATCH expression is executed only once.
    candidates = [_fts_and_query(keywords, max_terms=n) for n in (5, 4, 3)]
    candidates.append(_fts_or_query(keywords, max_terms=12))

    tried: set[str] = set()
    for fts_q in candidates:
        if not fts_q or fts_q in tried:
            continue
        tried.add(fts_q)
        out = _fts_select_best_effort(conn, fts_q, lim)
        if out:
            return out

    # Unfiltered tokens are only needed by the LIKE stage, which uses them
    # when stopword removal left no keywords.
    raw_tokens = [
        t.lower()
        for t in re.findall(r"[^\W_]+", q, flags=re.UNICODE)
        if len(t) >= 2
    ][:14]
    return _like_fallback(conn, keywords, raw_tokens, lim)
|
|
|
|
def get_commodity(conn: sqlite3.Connection, commodity_code: int) -> dict[str, Any] | None:
    """Look up a single ``commodities`` row by its ``commodity`` value.

    Returns the row passed through ``row_to_dict``, or ``None`` when no row
    has that code.
    """
    match = conn.execute(
        """
        SELECT * FROM commodities WHERE commodity = ? LIMIT 1
        """,
        (commodity_code,),
    ).fetchone()
    if not match:
        return None
    return row_to_dict(match)
|
|
|
|
def summarize_row(r: dict[str, Any]) -> dict[str, Any]:
    """Reduce a catalogue row dict to a compact summary (fewer tokens for agents).

    Codes are read from the nested ``codes`` mapping; missing text fields
    become empty strings and the definition is cut to 800 characters.
    """
    codes = r.get("codes") or {}
    path = r.get("path_titles") or ""
    title = r.get("commodity_title") or ""
    definition = r.get("commodity_definition") or ""

    summary: dict[str, Any] = {}
    summary["commodity_code"] = codes.get("commodity")
    summary["path"] = path
    summary["segment_code"] = codes.get("segment")
    summary["family_code"] = codes.get("family")
    summary["class_code"] = codes.get("class")
    summary["commodity_title"] = title
    summary["commodity_definition"] = definition[:800]
    return summary
|
|