gyubin02 commited on
Commit
d390d1b
·
1 Parent(s): 4e7c9dd

keyword filter

Browse files
Files changed (3) hide show
  1. indexer.py +64 -20
  2. keyword_filters.py +118 -0
  3. main.py +63 -33
indexer.py CHANGED
@@ -14,25 +14,15 @@ from PIL import Image
14
  from tqdm import tqdm
15
  from transformers import SiglipModel, SiglipProcessor
16
 
 
 
 
 
 
 
 
17
  IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
18
  T = TypeVar("T")
19
- CATEGORY_SYNONYMS = {
20
- "모자": ["모자", "헬름", "헬멧", "햇", "보닛", "캡"],
21
- "신발": ["신발", "슈즈", "부츠", "샌들"],
22
- "장갑": ["장갑", "글러브"],
23
- "무기": ["무기", "검", "소드", "대검", "스태프", "완드", "활", "석궁", "창", "스피어", "폴암", "도끼", "단검", "너클", "건", "총", "클로"],
24
- "상의": ["상의", "셔츠", "자켓", "코트", "로브", "블라우스"],
25
- "하의": ["하의", "바지", "팬츠", "스커트"],
26
- "망토": ["망토", "케이프", "cape"],
27
- "귀걸이": ["귀걸이", "귀고리", "이어링"],
28
- "반지": ["반지", "링"],
29
- "목걸이": ["목걸이", "펜던트", "네클리스"],
30
- "벨트": ["벨트"],
31
- "얼굴장식": ["얼굴장식", "얼굴 장식"],
32
- "눈장식": ["눈장식", "눈 장식"],
33
- "보조무기": ["보조무기", "보조 무기"],
34
- "방패": ["방패", "쉴드", "실드"],
35
- }
36
 
37
 
38
  def parse_args() -> argparse.Namespace:
@@ -154,12 +144,39 @@ def detect_category(texts: List[str]) -> Optional[str]:
154
  return None
155
 
156
 
157
- def load_labels(labels_path: Path) -> Dict[str, Dict[str, str]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  if not labels_path.exists():
159
  print(f"Labels file not found, continuing without labels: {labels_path}")
160
  return {}
161
 
162
- label_map: Dict[str, Dict[str, str]] = {}
163
  with labels_path.open("r", encoding="utf-8") as file:
164
  for line_no, line in enumerate(file, start=1):
165
  line = line.strip()
@@ -180,6 +197,15 @@ def load_labels(labels_path: Path) -> Dict[str, Dict[str, str]]:
180
  tags = record.get("tags_ko") or []
181
  tag_texts = [normalize_label(tag) for tag in tags if tag is not None]
182
  tag_texts = [tag for tag in tag_texts if tag]
 
 
 
 
 
 
 
 
 
183
  if not item_name and not label_ko and not tag_texts:
184
  continue
185
 
@@ -190,9 +216,27 @@ def load_labels(labels_path: Path) -> Dict[str, Dict[str, str]]:
190
  if label_ko:
191
  label_map[normalized_path]["label_ko"] = label_ko
192
  label_map[normalized_path]["label"] = label_ko
193
- category = detect_category([item_name or "", label_ko or "", *tag_texts])
 
 
 
 
 
 
 
 
194
  if category:
195
  label_map[normalized_path]["category"] = category
 
 
 
 
 
 
 
 
 
 
196
 
197
  print(f"Loaded labels for {len(label_map)} images from {labels_path}")
198
  return label_map
 
14
  from tqdm import tqdm
15
  from transformers import SiglipModel, SiglipProcessor
16
 
17
+ from keyword_filters import (
18
+ CATEGORY_SYNONYMS,
19
+ COLOR_SYNONYMS,
20
+ VIBE_SYNONYMS,
21
+ extract_keywords,
22
+ )
23
+
24
  IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
25
  T = TypeVar("T")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def parse_args() -> argparse.Namespace:
 
144
  return None
145
 
146
 
147
def collect_label_texts(
    item_name: Optional[str],
    label_ko: Optional[str],
    tags: List[str],
    query_variants: List[str],
    attributes: Dict[str, object],
    item_type_guess: Optional[str],
) -> List[str]:
    """Gather every searchable text fragment for one labeled image.

    Combines the item name, Korean label, guessed item type, tags, query
    variants and all attribute values (list values are flattened) into one
    flat list.  Attribute entries are passed through ``normalize_label``;
    empty/falsy values are dropped everywhere.  Duplicates are kept — the
    downstream keyword matchers tolerate them.
    """
    collected: List[str] = []

    # Scalar fields first, in a fixed order.
    for candidate in (item_name, label_ko, item_type_guess):
        if candidate:
            collected.append(candidate)

    collected += [tag for tag in tags if tag]
    collected += [variant for variant in query_variants if variant]

    # Attribute values may be scalars or lists; treat both uniformly.
    for raw in attributes.values():
        entries = raw if isinstance(raw, list) else [raw]
        for entry in entries:
            normalized = normalize_label(entry)
            if normalized:
                collected.append(normalized)

    return collected
172
+
173
+
174
+ def load_labels(labels_path: Path) -> Dict[str, Dict[str, object]]:
175
  if not labels_path.exists():
176
  print(f"Labels file not found, continuing without labels: {labels_path}")
177
  return {}
178
 
179
+ label_map: Dict[str, Dict[str, object]] = {}
180
  with labels_path.open("r", encoding="utf-8") as file:
181
  for line_no, line in enumerate(file, start=1):
182
  line = line.strip()
 
197
  tags = record.get("tags_ko") or []
198
  tag_texts = [normalize_label(tag) for tag in tags if tag is not None]
199
  tag_texts = [tag for tag in tag_texts if tag]
200
+ query_variants = record.get("query_variants_ko") or []
201
+ variant_texts = [
202
+ normalize_label(variant)
203
+ for variant in query_variants
204
+ if variant is not None
205
+ ]
206
+ variant_texts = [variant for variant in variant_texts if variant]
207
+ attributes = record.get("attributes") or {}
208
+ item_type_guess = normalize_label(attributes.get("item_type_guess"))
209
  if not item_name and not label_ko and not tag_texts:
210
  continue
211
 
 
216
  if label_ko:
217
  label_map[normalized_path]["label_ko"] = label_ko
218
  label_map[normalized_path]["label"] = label_ko
219
+ texts = collect_label_texts(
220
+ item_name,
221
+ label_ko,
222
+ tag_texts,
223
+ variant_texts,
224
+ attributes,
225
+ item_type_guess,
226
+ )
227
+ category = detect_category(texts)
228
  if category:
229
  label_map[normalized_path]["category"] = category
230
+ colors = extract_keywords(texts, COLOR_SYNONYMS)
231
+ if colors:
232
+ label_map[normalized_path]["colors"] = colors
233
+ for color in colors:
234
+ label_map[normalized_path][f"color_{color}"] = True
235
+ vibes = extract_keywords(texts, VIBE_SYNONYMS)
236
+ if vibes:
237
+ label_map[normalized_path]["vibes"] = vibes
238
+ for vibe in vibes:
239
+ label_map[normalized_path][f"vibe_{vibe}"] = True
240
 
241
  print(f"Loaded labels for {len(label_map)} images from {labels_path}")
242
  return label_map
keyword_filters.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Iterable, List
4
+
5
+
6
# Canonical equipment-category names (Korean) mapped to the variant strings
# that identify them.  Matching is done by ``extract_keywords`` below via
# case-insensitive *substring* search over item names/tags/queries, so each
# list mixes Korean terms, transliterations and plain-English equivalents.
# Insertion order determines the order of hits returned by the matcher.
# NOTE(review): very short variants (e.g. "링" for 반지, "건" for 무기) can
# substring-match inside unrelated words ("이어링" also contains "링") — an
# accepted trade-off here, but worth confirming against real label data.
CATEGORY_SYNONYMS = {
    "모자": ["모자", "헬름", "헬멧", "햇", "보닛", "캡", "hat", "cap", "helmet"],
    "신발": ["신발", "슈즈", "부츠", "샌들", "shoes", "shoe", "boots", "sandal"],
    "장갑": ["장갑", "글러브", "glove", "gloves"],
    "무기": [
        "무기",
        "검",
        "소드",
        "대검",
        "스태프",
        "완드",
        "활",
        "석궁",
        "창",
        "스피어",
        "폴암",
        "도끼",
        "단검",
        "너클",
        "건",
        "총",
        "클로",
        "weapon",
        "sword",
        "staff",
        "wand",
        "bow",
        "spear",
        "axe",
        "dagger",
        "gun",
        "claw",
    ],
    "상의": [
        "상의",
        "셔츠",
        "자켓",
        "코트",
        "로브",
        "블라우스",
        "top",
        "shirt",
        "jacket",
        "coat",
        "robe",
        "blouse",
    ],
    "하의": ["하의", "바지", "팬츠", "스커트", "bottom", "pants", "skirt"],
    "망토": ["망토", "케이프", "cape", "날개", "윙", "wing", "wings"],
    "귀걸이": ["귀걸이", "귀고리", "이어링", "earring", "earrings"],
    "반지": ["반지", "링", "ring"],
    "목걸이": ["목걸이", "펜던트", "네클리스", "necklace", "pendant"],
    "벨트": ["벨트", "belt"],
    "얼굴장식": ["얼굴장식", "얼굴 장식"],
    "눈장식": ["눈장식", "눈 장식"],
    "보조무기": ["보조무기", "보조 무기", "sub weapon", "subweapon", "offhand"],
    "방패": ["방패", "쉴드", "실드", "shield"],
}
64
+
65
# Canonical color keys (English) mapped to Korean/English variant substrings.
# Used by ``extract_keywords`` (case-insensitive substring match); the
# canonical key also becomes the ``color_<key>`` boolean metadata flag that
# the indexer writes and the search filter queries.
COLOR_SYNONYMS = {
    "black": ["검은", "검정", "블랙", "black"],
    "white": ["흰", "하얀", "화이트", "white"],
    "gray": ["회색", "그레이", "gray"],
    "silver": ["은색", "실버", "silver"],
    "gold": ["금색", "골드", "gold"],
    "red": ["빨간", "빨강", "레드", "적색", "붉은", "red"],
    "pink": ["핑크", "분홍", "분홍색", "핑크색", "pink"],
    "orange": ["주황", "오렌지", "orange"],
    "yellow": ["노란", "노랑", "옐로", "yellow"],
    "green": ["초록", "녹색", "그린", "green"],
    "blue": ["파란", "파랑", "블루", "blue", "하늘색", "스카이", "sky"],
    "purple": ["보라", "퍼플", "purple"],
    "brown": ["갈색", "브라운", "brown"],
    "beige": ["베이지", "beige"],
    "mint": ["민트", "mint"],
    "teal": ["청록", "teal", "터쿼이즈", "turquoise"],
    "navy": ["남색", "네이비", "navy"],
}
84
+
85
# Canonical "vibe" keys (English) mapped to Korean/English variant substrings.
# Used by ``extract_keywords`` (case-insensitive substring match); the
# canonical key also becomes the ``vibe_<key>`` boolean metadata flag.
# Fix: dropped the duplicate "VIP" from ``luxury`` — matching lowercases
# every variant, so it was identical to the existing "vip" entry.
# Some entries are subsumed by shorter substrings in the same list
# (e.g. "스포티한" by "스포티"); they are kept for readability.
# NOTE(review): "ダーク" under ``dark`` is Japanese katakana — presumably
# intentional for mixed-language labels; confirm against the label data.
VIBE_SYNONYMS = {
    "cute": ["귀여움", "귀여운", "귀엽", "큐트", "cute", "사랑스러운", "lovely"],
    "sporty": ["스포티", "스포츠", "sporty", "sports", "스포티한"],
    "casual": ["캐주얼", "casual"],
    "luxury": ["고급스러움", "고급", "luxury", "classy", "품격", "vip"],
    "elegant": ["우아", "elegant", "고상", "세련", "세련된"],
    "playful": ["유쾌한", "funny", "playful", "장난", "발랄"],
    "bright": ["빛나는", "sparkle", "glitter", "반짝", "sparkling"],
    "powerful": ["강력한", "전투적인", "전투용", "powerful", "강인"],
    "romantic": ["로맨틱", "romance", "romantic", "설렘", "사랑"],
    "mysterious": ["신비", "mysterious", "묘한"],
    "retro": ["레트로", "retro", "빈티지", "vintage", "클래식", "classic", "고전적인"],
    "futuristic": ["futuristic", "미래", "사이버", "sf"],
    "sweet": ["달달", "달콤", "sweet", "상큼"],
    "unique": ["유니크", "unique", "독특", "개성", "특별"],
    "calm": ["고요한", "차분", "calm"],
    "dark": ["다크", "dark", "ダーク", "어두운"],
}
103
+
104
+
105
def extract_keywords(
    texts: Iterable[str], synonyms: Dict[str, List[str]]
) -> List[str]:
    """Return the canonical keys whose variants appear in any of *texts*.

    A key matches when at least one of its variant strings occurs as a
    case-insensitive substring of at least one (non-empty) text.  Results
    preserve the insertion order of *synonyms* and contain each key at
    most once.
    """
    haystacks = [entry.lower() for entry in texts if entry]
    if not haystacks:
        return []

    def _mentioned(variants: List[str]) -> bool:
        # Substring scan: any variant found in any haystack counts.
        for variant in variants:
            needle = variant.lower()
            if any(needle in haystack for haystack in haystacks):
                return True
        return False

    return [key for key, variants in synonyms.items() if _mentioned(variants)]
main.py CHANGED
@@ -16,25 +16,14 @@ from peft import PeftModel
16
  from pydantic import BaseModel, Field
17
  from transformers import SiglipModel, SiglipProcessor
18
 
 
 
 
 
 
 
19
 
20
  DATA_DIR = (Path(__file__).resolve().parent / "data/2026-01-11").resolve()
21
- CATEGORY_SYNONYMS = {
22
- "모자": ["모자", "헬름", "헬멧", "햇", "보닛", "캡"],
23
- "신발": ["신발", "슈즈", "부츠", "샌들"],
24
- "장갑": ["장갑", "글러브"],
25
- "무기": ["무기", "검", "소드", "대검", "스태프", "완드", "활", "석궁", "창", "스피어", "폴암", "도끼", "단검", "너클", "건", "총", "클로"],
26
- "상의": ["상의", "셔츠", "자켓", "코트", "로브", "블라우스"],
27
- "하의": ["하의", "바지", "팬츠", "스커트"],
28
- "망토": ["망토", "케이프", "cape"],
29
- "귀걸이": ["귀걸이", "귀고리", "이어링"],
30
- "반지": ["반지", "링"],
31
- "목걸이": ["목걸이", "펜던트", "네클리스"],
32
- "벨트": ["벨트"],
33
- "얼굴장식": ["얼굴장식", "얼굴 장식"],
34
- "눈장식": ["눈장식", "눈 장식"],
35
- "보조무기": ["보조무기", "보조 무기"],
36
- "방패": ["방패", "쉴드", "실드"],
37
- }
38
 
39
 
40
  class SearchRequest(BaseModel):
@@ -51,21 +40,59 @@ def resolve_adapter_path(adapter_path: Path) -> Path:
51
  return adapter_path
52
 
53
 
54
- def extract_category_keywords(query: str) -> List[str]:
55
- keywords: List[str] = []
56
- lowered_query = query.lower()
57
- for category, variants in CATEGORY_SYNONYMS.items():
58
- for variant in variants:
59
- if variant.lower() in lowered_query and category not in keywords:
60
- keywords.append(category)
61
- break
62
- return keywords
63
 
64
 
65
- def build_metadata_filter(keywords: List[str]) -> Dict[str, Any] | None:
66
- if not keywords:
 
 
 
 
 
 
 
 
 
67
  return None
68
- return {"category": {"$in": keywords}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  @asynccontextmanager
@@ -150,11 +177,11 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
150
 
151
  query_embedding = text_embeds[0].detach().cpu().tolist()
152
 
153
- filter_keywords = extract_category_keywords(query)
154
- where_filter = build_metadata_filter(filter_keywords)
155
 
156
  results = None
157
- if where_filter:
158
  try:
159
  results = collection.query(
160
  query_embeddings=[query_embedding],
@@ -163,8 +190,11 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
163
  include=["distances", "metadatas"],
164
  )
165
  except Exception as exc: # noqa: BLE001
166
- print(f"Filtered query failed ({exc}); falling back to vector-only.")
167
  results = None
 
 
 
168
 
169
  if not results or not results.get("ids") or not results["ids"][0]:
170
  results = collection.query(
 
16
  from pydantic import BaseModel, Field
17
  from transformers import SiglipModel, SiglipProcessor
18
 
19
+ from keyword_filters import (
20
+ CATEGORY_SYNONYMS,
21
+ COLOR_SYNONYMS,
22
+ VIBE_SYNONYMS,
23
+ extract_keywords,
24
+ )
25
 
26
  DATA_DIR = (Path(__file__).resolve().parent / "data/2026-01-11").resolve()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  class SearchRequest(BaseModel):
 
40
  return adapter_path
41
 
42
 
43
def extract_query_filters(query: str) -> Dict[str, List[str]]:
    """Detect category/color/vibe keywords mentioned in a free-text query.

    Returns a dict with the keys ``categories``, ``colors`` and ``vibes``,
    each holding the canonical keywords found by ``extract_keywords``
    (possibly empty lists).
    """
    source = [query]
    facets = (
        ("categories", CATEGORY_SYNONYMS),
        ("colors", COLOR_SYNONYMS),
        ("vibes", VIBE_SYNONYMS),
    )
    return {name: extract_keywords(source, table) for name, table in facets}
 
 
50
 
51
 
52
+ def build_where_filter(
53
+ categories: List[str], colors: List[str], vibes: List[str]
54
+ ) -> Dict[str, Any] | None:
55
+ clauses: List[Dict[str, Any]] = []
56
+ if categories:
57
+ clauses.append({"category": {"$in": categories}})
58
+ if colors:
59
+ clauses.append({"$and": [{f"color_{color}": True} for color in colors]})
60
+ if vibes:
61
+ clauses.append({"$and": [{f"vibe_{vibe}": True} for vibe in vibes]})
62
+ if not clauses:
63
  return None
64
+ if len(clauses) == 1:
65
+ return clauses[0]
66
+ return {"$and": clauses}
67
+
68
+
69
def build_filter_candidates(filters: Dict[str, List[str]]) -> List[Dict[str, Any]]:
    """Produce ``where`` filters ordered from strictest to loosest.

    Each entry of the fixed combo table names the facets a candidate must
    include; combos referencing a facet with no detected keywords are
    skipped.  The caller tries the candidates in order until one yields
    results, so the full three-facet filter comes first and single-facet
    filters come last.
    """
    available = {
        "category": filters.get("categories") or [],
        "color": filters.get("colors") or [],
        "vibe": filters.get("vibes") or [],
    }
    combos = (
        ("category", "color", "vibe"),
        ("category", "color"),
        ("category", "vibe"),
        ("color", "vibe"),
        ("category",),
        ("color",),
        ("vibe",),
    )

    candidates: List[Dict[str, Any]] = []
    for combo in combos:
        # Skip combos that require a facet with no detected keywords.
        if any(not available[facet] for facet in combo):
            continue
        candidate = build_where_filter(
            available["category"] if "category" in combo else [],
            available["color"] if "color" in combo else [],
            available["vibe"] if "vibe" in combo else [],
        )
        if candidate:
            candidates.append(candidate)
    return candidates
96
 
97
 
98
  @asynccontextmanager
 
177
 
178
  query_embedding = text_embeds[0].detach().cpu().tolist()
179
 
180
+ filter_parts = extract_query_filters(query)
181
+ where_candidates = build_filter_candidates(filter_parts)
182
 
183
  results = None
184
+ for where_filter in where_candidates:
185
  try:
186
  results = collection.query(
187
  query_embeddings=[query_embedding],
 
190
  include=["distances", "metadatas"],
191
  )
192
  except Exception as exc: # noqa: BLE001
193
+ print(f"Filtered query failed ({exc}); trying less strict.")
194
  results = None
195
+ continue
196
+ if results and results.get("ids") and results["ids"][0]:
197
+ break
198
 
199
  if not results or not results.get("ids") or not results["ids"][0]:
200
  results = collection.query(