| import functools |
| import io |
| import sqlite3 |
|
|
| import numpy as np |
|
|
|
|
class ItemDatabase:
    """Query helper for the ``items`` table of an SQLite database file.

    Columns declared with the type ``embedding`` are decoded into NumPy arrays
    by ``_text_to_numpy_array``.  Every query opens a fresh connection and
    closes it before returning.
    """

    def __init__(self, db_path, cache_size=2**14):
        """Create a handle for the SQLite database at *db_path*.

        Args:
            db_path: Path passed unchanged to ``sqlite3.connect``.
            cache_size: Maximum number of rows memoized by ``get_item`` for
                this instance (``None`` means unbounded).  Defaults to the
                previous hard-coded size of 2**14.
        """
        # register_converter is process-wide (not per connection); registering
        # the same converter again for each new instance is harmless.
        sqlite3.register_converter("embedding", self._text_to_numpy_array)
        self._db_path = db_path
        # Per-instance memoization.  Decorating the method with lru_cache at
        # class level would key the cache on ``self`` and keep every instance
        # (and all its cached rows) alive for the life of the process.  The
        # cached wrapper stored here shadows the class-level ``get_item``, so
        # ``db.get_item(...)``, ``db.get_item.cache_info()`` and
        # ``db.get_item.cache_clear()`` keep working as before.
        self.get_item = functools.lru_cache(maxsize=cache_size)(self._fetch_item)

    @staticmethod
    def _text_to_numpy_array(text):
        """Decode an ``np.save``-serialized value read from SQLite.

        sqlite3 always hands converters ``bytes``, despite the parameter name.
        ``np.load`` runs with its default ``allow_pickle=False``, so pickled
        object arrays are rejected rather than unpickled.
        """
        # A new BytesIO is already positioned at offset 0; no seek needed.
        return np.load(io.BytesIO(text))

    def _connect(self):
        """Open a new connection with declared-type converters enabled.

        The caller owns the returned connection and must close it.
        """
        return sqlite3.connect(
            self._db_path, detect_types=sqlite3.PARSE_DECLTYPES)

    @staticmethod
    def _escape_like(text):
        """Escape LIKE metacharacters so *text* matches literally.

        Intended for use with ``ESCAPE '\\'`` in the SQL statement.
        """
        return (
            text.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    def search_items(self, query, n_items=10):
        """Return ids of items whose title contains *query*.

        Matching uses SQLite ``LIKE`` (case-insensitive for ASCII).  ``%``
        and ``_`` inside *query* are treated as literal characters.

        Args:
            query: Substring to look for; converted with ``str()``.
            n_items: Maximum number of ids to return.  ``None`` means no
                limit; values below zero are treated as zero.

        Returns:
            A list of ``item_id`` values, at most *n_items* long.
        """
        pattern = "%" + self._escape_like(str(query)) + "%"
        # SQLite treats a negative LIMIT as "no limit", so -1 encodes None.
        limit = -1 if n_items is None else max(n_items, 0)
        conn = self._connect()
        try:
            rows = conn.execute(
                "select item_id from items "
                "where title like ? escape '\\' limit ?",
                (pattern, limit),
            ).fetchall()
        finally:
            # The sqlite3 connection context manager only commits/rolls back;
            # it never closes, so close explicitly to avoid leaking handles.
            conn.close()
        return [row[0] for row in rows]

    def _fetch_item(self, item_id):
        """Fetch the row whose ``item_id`` equals *item_id* (uncached).

        Returns:
            A ``sqlite3.Row`` (indexable by column name), or ``None`` when no
            row matches.
        """
        conn = self._connect()
        try:
            conn.row_factory = sqlite3.Row
            # Exact, parameterized comparison: LIKE would treat "_" and "%"
            # in ids as wildcards and match unrelated rows.
            return conn.execute(
                "select * from items where item_id = ?", (item_id,)
            ).fetchone()
        finally:
            conn.close()

    def get_item(self, item_id):
        """Return the ``items`` row with the given *item_id*, or ``None``.

        Instances built through ``__init__`` replace this method with a
        per-instance ``functools.lru_cache`` wrapper around ``_fetch_item``.
        Once an id is cached, later changes to the database are not observed
        for it until ``get_item.cache_clear()`` is called.
        """
        return self._fetch_item(item_id)