gyubin02 committed on
Commit
68f7921
·
1 Parent(s): 619dbf0
Files changed (2) hide show
  1. indexer.py +64 -5
  2. main.py +10 -1
indexer.py CHANGED
@@ -2,8 +2,9 @@
2
  from __future__ import annotations
3
 
4
  import argparse
 
5
  from pathlib import Path
6
- from typing import Iterable, List, Tuple, TypeVar
7
 
8
  import chromadb
9
  import torch
@@ -55,6 +56,12 @@ def parse_args() -> argparse.Namespace:
55
  default="maple_items",
56
  help="ChromaDB collection name.",
57
  )
 
 
 
 
 
 
58
  return parser.parse_args()
59
 
60
 
@@ -111,6 +118,52 @@ def load_images(
111
  return images, valid_paths, valid_ids
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def main() -> None:
115
  args = parse_args()
116
 
@@ -121,6 +174,8 @@ def main() -> None:
121
 
122
  ids = build_ids(image_paths)
123
  adapter_path = resolve_adapter_path(args.adapter_path)
 
 
124
 
125
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
 
@@ -159,10 +214,14 @@ def main() -> None:
159
  embeds = F.normalize(embeds, dim=-1)
160
 
161
  embeddings = embeds.detach().cpu().tolist()
162
- metadatas = [
163
- {"filepath": str(path.relative_to(args.data_dir).as_posix())}
164
- for path in valid_paths
165
- ]
 
 
 
 
166
 
167
  collection.upsert(
168
  ids=valid_ids,
 
2
  from __future__ import annotations
3
 
4
  import argparse
5
+ import json
6
  from pathlib import Path
7
+ from typing import Dict, Iterable, List, Optional, Tuple, TypeVar
8
 
9
  import chromadb
10
  import torch
 
56
  default="maple_items",
57
  help="ChromaDB collection name.",
58
  )
59
+ parser.add_argument(
60
+ "--labels-path",
61
+ type=Path,
62
+ default=None,
63
+ help="Path to labels.jsonl (defaults to data-dir/labels/labels.jsonl).",
64
+ )
65
  return parser.parse_args()
66
 
67
 
 
118
  return images, valid_paths, valid_ids
119
 
120
 
121
def normalize_label(value: Optional[str]) -> Optional[str]:
    """Coerce a raw label value to a clean string, or None if unusable.

    Returns None for None input or whitespace-only strings; strips
    surrounding whitespace from strings; stringifies any other value
    (e.g. a numeric label read from JSON).
    """
    if value is None:
        return None
    if isinstance(value, str):
        trimmed = value.strip()
        return trimmed or None
    return str(value)


def load_labels(labels_path: Path) -> Dict[str, Dict[str, str]]:
    """Load a labels.jsonl file into a {relative_posix_path: labels} map.

    Each line is a JSON object with an "image_path" key and optional
    "item_name" / "label_ko" keys. Blank lines, malformed JSON, records
    without an image path, and records carrying no usable label are
    skipped (malformed lines with a diagnostic). A missing file is not
    an error: indexing proceeds without labels.
    """
    if not labels_path.exists():
        print(f"Labels file not found, continuing without labels: {labels_path}")
        return {}

    label_map: Dict[str, Dict[str, str]] = {}
    with labels_path.open("r", encoding="utf-8") as file:
        for line_no, line in enumerate(file, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"Skipping label line {line_no}: {exc}")
                continue

            image_path = record.get("image_path")
            if not image_path:
                continue

            item_name = normalize_label(record.get("item_name"))
            label_ko = normalize_label(record.get("label_ko"))
            if not item_name and not label_ko:
                continue

            # BUG FIX: the previous .lstrip("./") treated "./" as a character
            # SET, stripping any leading run of '.' and '/' characters
            # (e.g. ".hidden/a.png" -> "hidden/a.png"), which broke the join
            # against keys built with path.relative_to(data_dir).as_posix()
            # in main(). Path() construction already collapses a leading
            # "./" component, so no manual stripping is needed.
            normalized_path = Path(str(image_path)).as_posix()

            labels: Dict[str, str] = {}
            if item_name:
                labels["item_name"] = item_name
            if label_ko:
                labels["label_ko"] = label_ko
            label_map[normalized_path] = labels

    print(f"Loaded labels for {len(label_map)} images from {labels_path}")
    return label_map
165
+
166
+
167
  def main() -> None:
168
  args = parse_args()
169
 
 
174
 
175
  ids = build_ids(image_paths)
176
  adapter_path = resolve_adapter_path(args.adapter_path)
177
+ labels_path = args.labels_path or args.data_dir / "labels/labels.jsonl"
178
+ label_map = load_labels(labels_path)
179
 
180
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181
 
 
214
  embeds = F.normalize(embeds, dim=-1)
215
 
216
  embeddings = embeds.detach().cpu().tolist()
217
+ metadatas = []
218
+ for path in valid_paths:
219
+ rel_path = path.relative_to(args.data_dir).as_posix()
220
+ metadata = {"filepath": rel_path}
221
+ label_data = label_map.get(rel_path)
222
+ if label_data:
223
+ metadata.update(label_data)
224
+ metadatas.append(metadata)
225
 
226
  collection.upsert(
227
  ids=valid_ids,
main.py CHANGED
@@ -100,7 +100,7 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
100
  results = collection.query(
101
  query_embeddings=[query_embedding],
102
  n_results=payload.k,
103
- include=["distances", "metadatas"],
104
  )
105
 
106
  ids: List[str] = results.get("ids", [[]])[0]
@@ -110,8 +110,14 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
110
  response_items = []
111
  for item_id, distance, metadata in zip(ids, distances, metadatas):
112
  filepath = ""
 
 
113
  if metadata:
114
  filepath = metadata.get("filepath", "")
 
 
 
 
115
  image_url = f"/static/images/{filepath}" if filepath else ""
116
  similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
117
  response_items.append(
@@ -121,6 +127,9 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
121
  "distance": distance,
122
  "similarity": similarity,
123
  "image_url": image_url,
 
 
 
124
  }
125
  )
126
 
 
100
  results = collection.query(
101
  query_embeddings=[query_embedding],
102
  n_results=payload.k,
103
+ include=["distances", "metadatas", "ids"],
104
  )
105
 
106
  ids: List[str] = results.get("ids", [[]])[0]
 
110
  response_items = []
111
  for item_id, distance, metadata in zip(ids, distances, metadatas):
112
  filepath = ""
113
+ item_name = ""
114
+ label_ko = ""
115
  if metadata:
116
  filepath = metadata.get("filepath", "")
117
+ item_name = metadata.get("item_name", "") or ""
118
+ label_ko = metadata.get("label_ko", "") or ""
119
+ if not item_name and filepath:
120
+ item_name = Path(filepath).stem
121
  image_url = f"/static/images/{filepath}" if filepath else ""
122
  similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
123
  response_items.append(
 
127
  "distance": distance,
128
  "similarity": similarity,
129
  "image_url": image_url,
130
+ "item_name": item_name,
131
+ "label_ko": label_ko,
132
+ "label": label_ko,
133
  }
134
  )
135