modu-lightway / src /text_encoder.py
Munsusu's picture
Upload 3 files
31d61ad verified
"""모두의 빛길 — 한국어 텍스트 임베딩 인코더.
sentence-transformers 기반 한국어 임베딩으로 자연어 질문과 시설을 같은
벡터 공간에 매핑한다. 슬롯 추출이 실패하거나 부분적이어도 의미 기반
검색이 가능하게 해주는 Fallback이자 보조 시그널.
권장 모델
---------
- jhgan/ko-sroberta-multitask (~120MB, 빠름, 한국어 검색 표준)
- BAAI/bge-m3 (~300MB, 다국어, 더 정확)
HF Spaces 메모리 한계를 고려하면 ko-sroberta-multitask가 안전.
GPU 설정은 사용자가 별도 셀에서 진행한다(이 모듈은 device 지정 X).
"""
from __future__ import annotations
from typing import List, Optional, Union
from functools import lru_cache
import torch
from .data_schema import VenueNode
# ============================================================
# 한국어 텍스트 인코더
# ============================================================
class KoreanTextEncoder:
"""sentence-transformers 한국어 임베딩 래퍼.
인터페이스
---------
encode(texts) : 문자열 또는 리스트 → (N, D) 텐서 (정규화됨)
encode_venue(venue) : VenueNode 1개 → (D,) 텐서
encode_venues(venues): VenueNode 리스트 → (N, D) 텐서
similarity(q, vs) : 질문 벡터와 시설 벡터들 간 코사인 유사도 (이미 정규화됨이라 dot product)
"""
def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask",
device: Optional[str] = None):
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise ImportError(
"sentence-transformers가 필요합니다: "
"pip install sentence-transformers"
) from e
# device 인자가 None이면 sentence_transformers가 자동으로 cuda/cpu 결정
self.model = SentenceTransformer(model_name, device=device)
self.dim = self.model.get_sentence_embedding_dimension()
self.model_name = model_name
def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
"""문자열 1개 또는 리스트 → 정규화된 임베딩 텐서.
반환 shape:
str → (D,)
List[str] → (N, D)
"""
if isinstance(texts, str):
return self.model.encode(
texts, convert_to_tensor=True, normalize_embeddings=True,
)
return self.model.encode(
texts, convert_to_tensor=True, normalize_embeddings=True,
show_progress_bar=False,
)
def encode_venue(self, venue: VenueNode) -> torch.Tensor:
return self.encode(self._venue_to_text(venue))
def encode_venues(self, venues: List[VenueNode]) -> torch.Tensor:
"""시설 리스트를 벡터화. 호출 1회로 모든 시설 처리 (배치)."""
if not venues:
return torch.zeros((0, self.dim))
texts = [self._venue_to_text(v) for v in venues]
return self.encode(texts)
def similarity(self, q_vec: torch.Tensor, v_mat: torch.Tensor) -> torch.Tensor:
"""이미 정규화된 벡터들 간 코사인 유사도 = 내적.
q_vec : (D,) 또는 (B, D)
v_mat : (N, D)
반환: (N,) 또는 (B, N)
"""
if q_vec.dim() == 1:
return v_mat @ q_vec
return q_vec @ v_mat.t()
# --------------------------------------------------------
# 시설 → 텍스트 변환
# --------------------------------------------------------
@staticmethod
def _venue_to_text(v: VenueNode) -> str:
"""시설 1개를 검색용 텍스트로.
포함 정보: 이름, 카테고리, 접근성, 실내·외, 가격, 어르신 친화 등
예: "국립아시아문화전당 공연장 엘리베이터 있음 장애인화장실 실내 유료"
"""
parts: List[str] = [v.name, v.venue_type]
if v.has_elevator:
parts.append("엘리베이터")
if v.has_disabled_toilet:
parts.append("장애인화장실")
if v.has_ramp:
parts.append("경사로")
if v.indoor:
parts.append("실내")
else:
parts.append("야외")
if v.free:
parts.append("무료")
if v.age_friendly:
parts.append("어르신 친화")
# 카테고리 동의어를 같이 넣어 임베딩 매칭 폭 넓힘
cat_synonyms = {
"박물관": ["역사관", "유물"],
"미술관": ["갤러리", "전시", "그림", "예술"],
"도서관": ["책", "독서", "열람실"],
"공연장": ["공연", "콘서트홀", "극장", "공연관"],
"체험관": ["체험", "핸즈온"],
"카페": ["커피", "차"],
"문화센터": ["문화", "프로그램"],
"공원": ["공원", "산책", "쉼터"],
}
if v.venue_type in cat_synonyms:
parts.extend(cat_synonyms[v.venue_type])
return " ".join(parts)
# --------------------------------------------------------
# 질문 → 텍스트 정규화 (옵션)
# --------------------------------------------------------
@staticmethod
def normalize_query(question: str) -> str:
"""질문 임베딩 전에 가벼운 정규화. 현재는 그대로 반환.
필요하면 형태소 분석이나 동의어 치환 추가."""
return question.strip()
# ============================================================
# 캐시된 시설 임베딩 (Pipeline에서 재사용)
# ============================================================
class VenueEmbeddingCache:
"""시설 임베딩을 한 번만 계산하고 재사용.
Pipeline 인스턴스 생성 시 미리 계산해두면 매 추천 호출마다
수백 개 시설을 다시 임베딩하지 않아 빠르다.
"""
def __init__(self, encoder: KoreanTextEncoder, venues: List[VenueNode]):
self.encoder = encoder
self.venues = venues
self.embeddings = encoder.encode_venues(venues) # (N, D)
self.id_to_idx = {v.id: i for i, v in enumerate(venues)}
def get(self, venue_ids: List[int]) -> torch.Tensor:
idx = [self.id_to_idx[vid] for vid in venue_ids if vid in self.id_to_idx]
if not idx:
return torch.zeros((0, self.embeddings.size(-1)))
return self.embeddings[idx]
def all(self) -> torch.Tensor:
return self.embeddings
def to(self, device) -> "VenueEmbeddingCache":
self.embeddings = self.embeddings.to(device)
return self
__all__ = ["KoreanTextEncoder", "VenueEmbeddingCache"]