mbtibooktalk / server /semantic_engine.py
dongle0516's picture
feat: 사서에게 질문하기 종합 개선 (1차~4차 누적)
1fa3750
"""Korean semantic-embedding search engine.

- Uses sentence-transformers with the ko-sroberta-multitask model.
- Disables itself automatically when the environment cannot support it
  (model not installed, insufficient memory, etc.).
- The embedding matrix is cached on disk (pickle file ``semantic_cache.pkl``
  next to this module) so it can be loaded immediately after a restart.

ENV
---
SEMANTIC_MODEL   : name of the Korean embedding model
                   (default: jhgan/ko-sroberta-multitask)
SEMANTIC_ENABLED : set to "0" to force-disable the engine
"""
from __future__ import annotations
import os
import pickle
from pathlib import Path
from typing import Optional
import numpy as np
from .database import get_conn, get_items_for_index
# --- Configuration, read once at import time -------------------------------
# Name of the sentence-transformers model; overridable via SEMANTIC_MODEL.
MODEL_NAME = os.environ.get("SEMANTIC_MODEL", "jhgan/ko-sroberta-multitask")
# Pickle file holding the cached index, stored next to this module.
CACHE_PATH = Path(__file__).with_name("semantic_cache.pkl")
# True only when SEMANTIC_ENABLED is exactly the string "0".
DISABLED = os.environ.get("SEMANTIC_ENABLED", "1") == "0"

# --- Process-wide lazy state -----------------------------------------------
_model = None  # loaded model instance; None until the first successful load
_load_failed: bool = False  # set after a failed load so it is not retried
_embeddings: Optional[np.ndarray] = None  # embedding matrix; None until built or loaded
_ids: list[int] = []  # "id" of each indexed row, in embedding-row order
_rec_keys: list[str] = []  # "rec_key" of each indexed row, same order as _ids
def _try_load_model():
    """Lazily load the sentence-transformers model.

    Returns:
        The cached model instance, loading ``MODEL_NAME`` on first use.
        ``None`` when the engine is disabled, when an earlier load attempt
        already failed, or when this attempt raises.
    """
    global _model, _load_failed
    if DISABLED or _load_failed:
        return None
    if _model is None:
        try:
            # Imported here rather than at module level so that a missing
            # package only disables this engine instead of breaking import.
            from sentence_transformers import SentenceTransformer  # type: ignore

            _model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            # Remember the failure so later calls return None without retrying.
            _load_failed = True
            print(f"[semantic_engine] ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ โ†’ TF-IDF ํด๋ฐฑ: {e}")
            return None
    return _model
def is_available() -> bool:
    """Report whether semantic search can be used.

    Returns:
        ``False`` when the engine is disabled. Otherwise ``True`` if an
        embedding matrix is already in memory, or if the model can be loaded
        (calling this may therefore trigger the model load).
    """
    if DISABLED:
        return False
    return _embeddings is not None or _try_load_model() is not None
def build_index() -> bool:
    """Embed every indexable Q&A item and cache the result on disk.

    The in-memory index (``_embeddings``, ``_ids``, ``_rec_keys``) is replaced
    only after encoding has succeeded, so a failed rebuild cannot leave ids
    that no longer line up with the embedding rows. The cache file is written
    to a temporary file and renamed into place, so an interrupted write does
    not leave a truncated cache behind.

    Returns:
        ``True`` when the index was rebuilt and cached; ``False`` when the
        model is unavailable or there are no rows to index.

    Raises:
        Whatever ``model.encode`` or the cache write raises; neither is
        caught here.
    """
    global _embeddings, _ids, _rec_keys
    model = _try_load_model()
    if model is None:
        return False
    rows = get_items_for_index()
    if not rows:
        return False

    # Stage everything in locals; globals are published together below.
    ids = [r["id"] for r in rows]
    rec_keys = [r["rec_key"] for r in rows]
    # Question and answer are joined and cut to the first 1024 characters.
    # NOTE(review): a None answer would be embedded as the text "None" --
    # confirm get_items_for_index() never yields NULL answers.
    texts = [f"{r['question']} {r['answer']}"[:1024] for r in rows]
    # NOTE(review): the message text below looks mis-encoded; it is runtime
    # output and is left untouched -- confirm the file's source encoding.
    print(f"[semantic_engine] {len(texts)}๊ฑด ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ ์ค‘...")
    embs = model.encode(
        texts,
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    # Encoding succeeded: publish the three parallel structures together.
    _embeddings = embs.astype(np.float32)
    _ids = ids
    _rec_keys = rec_keys

    payload = {
        "ids": _ids,
        "rec_keys": _rec_keys,
        "embeddings": _embeddings,
        "model": MODEL_NAME,
    }
    # Write next to the real cache, then rename over it in one step.
    tmp_path = CACHE_PATH.with_name(CACHE_PATH.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(payload, f)
        os.replace(tmp_path, CACHE_PATH)
    finally:
        # No-op after a successful replace; removes a partial file otherwise.
        if tmp_path.exists():
            tmp_path.unlink()
    print(f"[semantic_engine] ์ž„๋ฒ ๋”ฉ ์บ์‹œ ์ €์žฅ ์™„๋ฃŒ ({_embeddings.shape})")
    return True
def _ensure_loaded() -> bool:
    """Make sure the embedding index is in memory, reading the disk cache if needed.

    The cache is validated before any global is touched, so a damaged or
    mismatching file cannot leave the module half-loaded (embeddings set but
    ids empty).

    Returns:
        ``True`` when an index is in memory (already present or just loaded).
        ``False`` when there is no cache file, it cannot be read, it was built
        with a different model than ``MODEL_NAME``, or its ids do not match
        the number of embedding rows.
    """
    global _embeddings, _ids, _rec_keys
    if _embeddings is not None:
        return True
    if not CACHE_PATH.exists():
        return False
    try:
        # NOTE(security): pickle.load can execute arbitrary code from the
        # file. This is acceptable only while CACHE_PATH is written solely by
        # this module; do not point it at data from an untrusted source.
        with open(CACHE_PATH, "rb") as f:
            data = pickle.load(f)
        embeddings = data["embeddings"]
        ids = data["ids"]
        rec_keys = data["rec_keys"]
        # Caches without a "model" entry cannot be checked and are accepted.
        cached_model = data.get("model", MODEL_NAME)
        if cached_model != MODEL_NAME:
            # Vectors from another model are not comparable with queries
            # encoded by the current one.
            raise ValueError(
                f"cache built with model {cached_model!r}, expected {MODEL_NAME!r}"
            )
        if len(ids) != len(embeddings) or len(rec_keys) != len(ids):
            raise ValueError("cache ids/rec_keys do not match embedding rows")
    except Exception as e:
        print(f"[semantic_engine] ์บ์‹œ ๋กœ๋“œ ์‹คํŒจ: {e}")
        return False
    # Everything checked out: publish the three parallel structures together.
    _embeddings = embeddings
    _ids = ids
    _rec_keys = rec_keys
    return True
def semantic_search(query: str, top_k: int = 10, min_score: float = 0.2) -> list[dict]:
    """Search Q&A items by embedding similarity to ``query``.

    Scores are the dot product between the stored embedding matrix and the
    query embedding (the query is encoded with ``normalize_embeddings=True``).

    Args:
        query: Free-text query. A blank query returns no results.
        top_k: Maximum number of results. Non-positive values return no
            results.
        min_score: Minimum score a candidate must reach to be returned.

    Returns:
        Up to ``top_k`` dicts ordered by descending score. Each dict holds the
        selected ``qa_items`` columns plus ``"score"`` (float) and
        ``"search_type"`` (always ``"semantic"``). An empty list is returned
        when the engine is unavailable, no index is loaded, or nothing
        reaches ``min_score``.
    """
    # Input guards first. A negative top_k would otherwise slice from the
    # wrong end and return almost every indexed row.
    if top_k <= 0 or not query or not query.strip():
        return []
    if not is_available():
        return []
    if not _ensure_loaded():
        return []
    model = _try_load_model()
    if model is None:
        return []

    q_vec = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]
    scores = (_embeddings @ q_vec).astype(np.float32)

    # Keep up to twice as many candidates as requested, so that index entries
    # whose row is no longer in qa_items do not shrink the final result.
    ranked = np.argsort(scores)[::-1][: top_k * 2]
    candidates = [int(i) for i in ranked if scores[i] >= min_score]
    if not candidates:
        return []

    id_list = [_ids[i] for i in candidates]
    # Only "?" markers are interpolated into the SQL; the ids are bound.
    placeholders = ",".join("?" * len(id_list))
    with get_conn() as conn:
        rows = conn.execute(
            f"SELECT id, rec_key, question, answer, subject, answer_date, answer_lib "
            f"FROM qa_items WHERE id IN ({placeholders})",
            id_list,
        ).fetchall()
    by_id = {r["id"]: dict(r) for r in rows}

    # Rebuild in score order (the IN query does not preserve it) and stop
    # once top_k rows have been collected.
    results = []
    for i in candidates:
        row = by_id.get(_ids[i])
        if row is None:
            continue
        results.append({**row, "score": float(scores[i]), "search_type": "semantic"})
        if len(results) == top_k:
            break
    return results
def invalidate_cache() -> None:
    """Drop the in-memory index and delete the on-disk cache file, if present.

    The loaded model itself is left untouched; only the embeddings, ids and
    rec_keys are reset.
    """
    global _embeddings, _ids, _rec_keys
    _embeddings = None
    _ids = []
    _rec_keys = []
    # Attempt the delete directly instead of checking exists() first: the
    # file may vanish between the check and the unlink.
    try:
        CACHE_PATH.unlink()
    except FileNotFoundError:
        # No cache file, so there is nothing to remove.
        pass