| from __future__ import annotations |
|
|
| import json |
| import os |
| from dataclasses import dataclass |
| from typing import Any, Dict, List |
|
|
| from .util import simple_jp_tokenize |
|
|
|
|
@dataclass
class DBHit:
    """A single search result: a stored item paired with its relevance score."""

    # Relevance score — in practice the count of query tokens overlapping
    # the item's trigger tokens (see LatticeDB.search).
    score: float
    # The raw stored record, exactly as held in the DB's "items" list.
    item: Dict[str, Any]
|
|
|
|
| class LatticeDB: |
| def __init__(self, path: str): |
| self.path = path |
| os.makedirs(os.path.dirname(self.path), exist_ok=True) |
| self._data: Dict[str, Any] = {"items": []} |
| self._load() |
|
|
| def _load(self) -> None: |
| if not os.path.exists(self.path): |
| self._flush() |
| return |
| try: |
| with open(self.path, "r", encoding="utf-8") as f: |
| self._data = json.load(f) |
| if "items" not in self._data or not isinstance(self._data["items"], list): |
| self._data = {"items": []} |
| except Exception: |
| self._data = {"items": []} |
|
|
| def _flush(self) -> None: |
| with open(self.path, "w", encoding="utf-8") as f: |
| json.dump(self._data, f, ensure_ascii=False, indent=2) |
|
|
| def add(self, item: Dict[str, Any]) -> None: |
| self._data["items"].append(item) |
| self._flush() |
|
|
| def search(self, query: str, required_tags: List[str] | None = None, exclude_types: List[str] | None = None, k: int = 8) -> List[DBHit]: |
| required_tags = required_tags or [] |
| exclude_types = exclude_types or [] |
| q_tokens = set(simple_jp_tokenize(query)) |
|
|
| hits: List[DBHit] = [] |
| for it in self._data.get("items", []): |
| it_type = str(it.get("type", "")) |
| if it_type in exclude_types: |
| continue |
|
|
| tags = it.get("tags") or [] |
| if required_tags and not all(t in tags for t in required_tags): |
| continue |
|
|
| trig = it.get("trigger") or [] |
| trig_tokens = set([str(x) for x in trig]) |
|
|
| overlap = len(q_tokens & trig_tokens) |
| if overlap <= 0: |
| continue |
|
|
| |
| if overlap < 2 and required_tags: |
| continue |
|
|
| hits.append(DBHit(score=float(overlap), item=it)) |
|
|
| hits.sort(key=lambda h: h.score, reverse=True) |
| return hits[:k] |