File size: 6,788 Bytes
31d61ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""모두의 빛길 — 한국어 텍스트 임베딩 인코더.

sentence-transformers 기반 한국어 임베딩으로 자연어 질문과 시설을 같은
벡터 공간에 매핑한다. 슬롯 추출이 실패하거나 부분적이어도 의미 기반
검색이 가능하게 해주는 Fallback이자 보조 시그널.

권장 모델
---------
- jhgan/ko-sroberta-multitask  (~120MB, 빠름, 한국어 검색 표준)
- BAAI/bge-m3                  (~300MB, 다국어, 더 정확)

HF Spaces 메모리 한계를 고려하면 ko-sroberta-multitask가 안전.

GPU 설정은 사용자가 별도 셀에서 진행한다(이 모듈은 device 지정 X).
"""
from __future__ import annotations
from typing import List, Optional, Union
from functools import lru_cache

import torch

from .data_schema import VenueNode


# ============================================================
# 한국어 텍스트 인코더
# ============================================================

class KoreanTextEncoder:
    """sentence-transformers 한국어 임베딩 래퍼.

    인터페이스
    ---------
    encode(texts)        : 문자열 또는 리스트 → (N, D) 텐서 (정규화됨)
    encode_venue(venue)  : VenueNode 1개 → (D,) 텐서
    encode_venues(venues): VenueNode 리스트 → (N, D) 텐서
    similarity(q, vs)    : 질문 벡터와 시설 벡터들 간 코사인 유사도 (이미 정규화됨이라 dot product)
    """

    def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask",
                 device: Optional[str] = None):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise ImportError(
                "sentence-transformers가 필요합니다: "
                "pip install sentence-transformers"
            ) from e

        # device 인자가 None이면 sentence_transformers가 자동으로 cuda/cpu 결정
        self.model = SentenceTransformer(model_name, device=device)
        self.dim = self.model.get_sentence_embedding_dimension()
        self.model_name = model_name

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """문자열 1개 또는 리스트 → 정규화된 임베딩 텐서.

        반환 shape:
            str    → (D,)
            List[str] → (N, D)
        """
        if isinstance(texts, str):
            return self.model.encode(
                texts, convert_to_tensor=True, normalize_embeddings=True,
            )
        return self.model.encode(
            texts, convert_to_tensor=True, normalize_embeddings=True,
            show_progress_bar=False,
        )

    def encode_venue(self, venue: VenueNode) -> torch.Tensor:
        return self.encode(self._venue_to_text(venue))

    def encode_venues(self, venues: List[VenueNode]) -> torch.Tensor:
        """시설 리스트를 벡터화. 호출 1회로 모든 시설 처리 (배치)."""
        if not venues:
            return torch.zeros((0, self.dim))
        texts = [self._venue_to_text(v) for v in venues]
        return self.encode(texts)

    def similarity(self, q_vec: torch.Tensor, v_mat: torch.Tensor) -> torch.Tensor:
        """이미 정규화된 벡터들 간 코사인 유사도 = 내적.

        q_vec : (D,) 또는 (B, D)
        v_mat : (N, D)
        반환: (N,) 또는 (B, N)
        """
        if q_vec.dim() == 1:
            return v_mat @ q_vec
        return q_vec @ v_mat.t()

    # --------------------------------------------------------
    # 시설 → 텍스트 변환
    # --------------------------------------------------------

    @staticmethod
    def _venue_to_text(v: VenueNode) -> str:
        """시설 1개를 검색용 텍스트로.

        포함 정보: 이름, 카테고리, 접근성, 실내·외, 가격, 어르신 친화 등
        예: "국립아시아문화전당 공연장 엘리베이터 있음 장애인화장실 실내 유료"
        """
        parts: List[str] = [v.name, v.venue_type]

        if v.has_elevator:
            parts.append("엘리베이터")
        if v.has_disabled_toilet:
            parts.append("장애인화장실")
        if v.has_ramp:
            parts.append("경사로")
        if v.indoor:
            parts.append("실내")
        else:
            parts.append("야외")
        if v.free:
            parts.append("무료")
        if v.age_friendly:
            parts.append("어르신 친화")

        # 카테고리 동의어를 같이 넣어 임베딩 매칭 폭 넓힘
        cat_synonyms = {
            "박물관":   ["역사관", "유물"],
            "미술관":   ["갤러리", "전시", "그림", "예술"],
            "도서관":   ["책", "독서", "열람실"],
            "공연장":   ["공연", "콘서트홀", "극장", "공연관"],
            "체험관":   ["체험", "핸즈온"],
            "카페":     ["커피", "차"],
            "문화센터": ["문화", "프로그램"],
            "공원":     ["공원", "산책", "쉼터"],
        }
        if v.venue_type in cat_synonyms:
            parts.extend(cat_synonyms[v.venue_type])

        return " ".join(parts)

    # --------------------------------------------------------
    # 질문 → 텍스트 정규화 (옵션)
    # --------------------------------------------------------

    @staticmethod
    def normalize_query(question: str) -> str:
        """질문 임베딩 전에 가벼운 정규화. 현재는 그대로 반환.
        필요하면 형태소 분석이나 동의어 치환 추가."""
        return question.strip()


# ============================================================
# 캐시된 시설 임베딩 (Pipeline에서 재사용)
# ============================================================

class VenueEmbeddingCache:
    """시설 임베딩을 한 번만 계산하고 재사용.

    Pipeline 인스턴스 생성 시 미리 계산해두면 매 추천 호출마다
    수백 개 시설을 다시 임베딩하지 않아 빠르다.
    """

    def __init__(self, encoder: KoreanTextEncoder, venues: List[VenueNode]):
        self.encoder = encoder
        self.venues = venues
        self.embeddings = encoder.encode_venues(venues)   # (N, D)
        self.id_to_idx = {v.id: i for i, v in enumerate(venues)}

    def get(self, venue_ids: List[int]) -> torch.Tensor:
        idx = [self.id_to_idx[vid] for vid in venue_ids if vid in self.id_to_idx]
        if not idx:
            return torch.zeros((0, self.embeddings.size(-1)))
        return self.embeddings[idx]

    def all(self) -> torch.Tensor:
        return self.embeddings

    def to(self, device) -> "VenueEmbeddingCache":
        self.embeddings = self.embeddings.to(device)
        return self


__all__ = ["KoreanTextEncoder", "VenueEmbeddingCache"]