# NOTE(review): lines here were non-Python residue from a web-page scrape
# ("maple-data / indexer.py", "gyubin02's picture", commit 2b2c6ba) and made
# the file a syntax error; preserved as a comment so the script can run.
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar
import chromadb
import torch
import torch.nn.functional as F
from peft import PeftModel
from PIL import Image
from tqdm import tqdm
from transformers import SiglipModel, SiglipProcessor
from keyword_filters import (
CATEGORY_SYNONYMS,
COLOR_SYNONYMS,
VIBE_SYNONYMS,
extract_keywords,
)
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
T = TypeVar("T")
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the indexer."""
    arg_parser = argparse.ArgumentParser(
        description="Index images with SigLIP + LoRA embeddings into ChromaDB."
    )
    # (flag, add_argument keyword arguments) pairs, registered in order so the
    # generated --help output matches the original layout.
    option_specs = [
        ("--data-dir", {"type": Path, "default": Path("data/2026-01-11"),
                        "help": "Root directory containing images to index."}),
        ("--model-id", {"default": "google/siglip-base-patch16-256-multilingual",
                        "help": "Base SigLIP model ID."}),
        ("--adapter-path", {"type": Path, "default": Path("outputs/ko-clip-lora"),
                            "help": "Path to the LoRA adapter directory."}),
        ("--batch-size", {"type": int, "default": 32,
                          "help": "Batch size for image embedding."}),
        ("--chroma-path", {"type": Path, "default": Path("chroma_db"),
                           "help": "Directory for ChromaDB persistent storage."}),
        ("--collection", {"default": "maple_items",
                          "help": "ChromaDB collection name."}),
        ("--labels-path", {"type": Path, "default": None,
                           "help": "Path to labels.jsonl (defaults to data-dir/labels/labels.jsonl)."}),
    ]
    for flag, kwargs in option_specs:
        arg_parser.add_argument(flag, **kwargs)
    return arg_parser.parse_args()
def resolve_adapter_path(adapter_path: Path) -> Path:
    """Locate the directory that actually holds the LoRA adapter.

    Prefers *adapter_path* itself, then its ``best_model`` subdirectory;
    falls back to *adapter_path* unchanged when neither contains an
    ``adapter_config.json``.
    """
    for candidate in (adapter_path, adapter_path / "best_model"):
        if (candidate / "adapter_config.json").exists():
            return candidate
    return adapter_path
def find_images(root: Path) -> List[Path]:
    """Recursively collect image files under *root*, sorted by path.

    A file counts as an image when its (lowercased) suffix is in
    ``IMAGE_EXTENSIONS``.

    Raises:
        FileNotFoundError: if *root* does not exist.
    """
    if not root.exists():
        raise FileNotFoundError(f"Data directory not found: {root}")
    matches = (
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in IMAGE_EXTENSIONS
    )
    return sorted(matches)
def batch_iter(items: List[T], batch_size: int) -> Iterable[List[T]]:
    """Yield consecutive slices of *items*, each at most *batch_size* long."""
    cursor = 0
    total = len(items)
    while cursor < total:
        yield items[cursor : cursor + batch_size]
        cursor += batch_size
def build_ids(paths: List[Path]) -> List[str]:
    """Derive a unique ChromaDB id for every image path.

    The id is the file stem; repeated stems get a ``_<n>`` suffix. Unlike a
    plain per-stem counter, collisions with *pre-existing* stems are also
    resolved (e.g. files ``a.png``, ``a_1.png``, ``a.png`` previously produced
    the duplicate id ``a_1``), so the returned ids are guaranteed unique.

    Args:
        paths: image paths, in the order they will be indexed.

    Returns:
        A list of unique ids, parallel to *paths*.
    """
    counts: Dict[str, int] = {}
    used: set = set()
    ids: List[str] = []
    for path in paths:
        stem = path.stem
        count = counts.get(stem, 0)
        candidate = stem if count == 0 else f"{stem}_{count}"
        # Bump the counter past any id already taken (fixes duplicate ids
        # when a literal "<stem>_<n>" file coexists with repeated "<stem>").
        while candidate in used:
            count += 1
            candidate = f"{stem}_{count}"
        used.add(candidate)
        counts[stem] = count + 1
        ids.append(candidate)
    return ids
def load_images(
    paths: List[Path], ids: List[str]
) -> Tuple[List[Image.Image], List[Path], List[str]]:
    """Open each path as an RGB PIL image, dropping unreadable files.

    Returns parallel lists of loaded images, their paths, and their ids;
    any entry that fails to open is reported to stdout and skipped.
    """
    loaded: List[Image.Image] = []
    kept_paths: List[Path] = []
    kept_ids: List[str] = []
    for candidate, candidate_id in zip(paths, ids):
        try:
            with Image.open(candidate) as handle:
                rgb = handle.convert("RGB")
        except Exception as exc:  # noqa: BLE001
            print(f"Skipping unreadable image: {candidate} ({exc})")
            continue
        loaded.append(rgb)
        kept_paths.append(candidate)
        kept_ids.append(candidate_id)
    return loaded, kept_paths, kept_ids
def normalize_label(value: Optional[str]) -> Optional[str]:
    """Coerce a raw label value into a clean string, or None when empty.

    None stays None; strings are stripped (whitespace-only becomes None);
    any other type is stringified as-is.
    """
    if value is None:
        return None
    if not isinstance(value, str):
        return str(value)
    stripped = value.strip()
    return stripped if stripped else None
def detect_category(texts: List[str]) -> Optional[str]:
    """Return the first category whose synonym occurs in any of *texts*.

    Matching is case-insensitive substring containment; categories are
    checked in ``CATEGORY_SYNONYMS`` insertion order, and None is returned
    when nothing matches.
    """
    haystacks = [entry.lower() for entry in texts if entry]
    for category, keywords in CATEGORY_SYNONYMS.items():
        if any(
            keyword.lower() in haystack
            for keyword in keywords
            for haystack in haystacks
        ):
            return category
    return None
def collect_label_texts(
    item_name: Optional[str],
    label_ko: Optional[str],
    tags: List[str],
    query_variants: List[str],
    attributes: Dict[str, object],
    item_type_guess: Optional[str],
) -> List[str]:
    """Flatten every textual label source for one record into a single list.

    Order: primary names (item name, Korean label, type guess) first, then
    tags, then query variants, then any string-able values found inside
    *attributes* (list values are expanded element-wise).
    """

    def _clean(raw: object) -> Optional[str]:
        # Local copy of the module's label normalisation: strip strings,
        # drop empties/None, stringify everything else.
        if raw is None:
            return None
        if isinstance(raw, str):
            stripped = raw.strip()
            return stripped or None
        return str(raw)

    collected = [text for text in (item_name, label_ko, item_type_guess) if text]
    collected.extend(tag for tag in tags if tag)
    collected.extend(variant for variant in query_variants if variant)
    for raw_value in attributes.values():
        entries = raw_value if isinstance(raw_value, list) else [raw_value]
        for entry in entries:
            cleaned = _clean(entry)
            if cleaned:
                collected.append(cleaned)
    return collected
def load_labels(labels_path: Path) -> Dict[str, Dict[str, object]]:
    """Parse a labels.jsonl file into a metadata map keyed by POSIX image path.

    Each kept record contributes a dict that may contain ``item_name``,
    ``label_ko`` (duplicated under ``label``), a detected ``category``, and
    boolean ``color_*`` / ``vibe_*`` flags extracted from all label texts.
    Skipped: blank lines, malformed JSON (reported per line number), records
    without an image path, and records lacking a name, Korean label, and tags.
    A missing file is reported and yields an empty mapping (best effort).
    """
    if not labels_path.exists():
        print(f"Labels file not found, continuing without labels: {labels_path}")
        return {}
    label_map: Dict[str, Dict[str, object]] = {}
    with labels_path.open("r", encoding="utf-8") as file:
        for line_no, line in enumerate(file, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"Skipping label line {line_no}: {exc}")
                continue
            image_path = record.get("image_path")
            if not image_path:
                continue
            item_name = normalize_label(record.get("item_name"))
            label_ko = normalize_label(record.get("label_ko"))
            tags = record.get("tags_ko") or []
            tag_texts = [normalize_label(tag) for tag in tags if tag is not None]
            tag_texts = [tag for tag in tag_texts if tag]
            query_variants = record.get("query_variants_ko") or []
            variant_texts = [
                normalize_label(variant)
                for variant in query_variants
                if variant is not None
            ]
            variant_texts = [variant for variant in variant_texts if variant]
            attributes = record.get("attributes") or {}
            item_type_guess = normalize_label(attributes.get("item_type_guess"))
            if not item_name and not label_ko and not tag_texts:
                continue
            # BUG FIX: the previous code used .lstrip("./"), which strips any
            # leading run of '.' and '/' characters — e.g. ".config/img.png"
            # became "config/img.png" and "../x.png" became "x.png",
            # corrupting the lookup keys. Remove only a literal "./" prefix.
            normalized_path = Path(str(image_path)).as_posix()
            if normalized_path.startswith("./"):
                normalized_path = normalized_path[2:]
            metadata: Dict[str, object] = {}
            if item_name:
                metadata["item_name"] = item_name
            if label_ko:
                # Stored under both keys so consumers reading either
                # "label_ko" or "label" keep working.
                metadata["label_ko"] = label_ko
                metadata["label"] = label_ko
            texts = collect_label_texts(
                item_name,
                label_ko,
                tag_texts,
                variant_texts,
                attributes,
                item_type_guess,
            )
            category = detect_category(texts)
            if category:
                metadata["category"] = category
            colors = extract_keywords(texts, COLOR_SYNONYMS)
            if colors:
                for color in colors:
                    metadata[f"color_{color}"] = True
            vibes = extract_keywords(texts, VIBE_SYNONYMS)
            if vibes:
                for vibe in vibes:
                    metadata[f"vibe_{vibe}"] = True
            # Later duplicates of the same path overwrite earlier entries,
            # matching the original behavior.
            label_map[normalized_path] = metadata
    print(f"Loaded labels for {len(label_map)} images from {labels_path}")
    return label_map
def main() -> None:
    """Embed every image under --data-dir with SigLIP (+LoRA adapter) and
    upsert the vectors plus label metadata into a persistent ChromaDB
    collection."""
    args = parse_args()
    image_paths = find_images(args.data_dir)
    if not image_paths:
        print(f"No images found under {args.data_dir}")
        return
    # Stable, de-duplicated ids derived from the image file stems.
    ids = build_ids(image_paths)
    adapter_path = resolve_adapter_path(args.adapter_path)
    # Default labels location is <data-dir>/labels/labels.jsonl.
    labels_path = args.labels_path or args.data_dir / "labels/labels.jsonl"
    label_map = load_labels(labels_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Loading model and processor...")
    base_model = SiglipModel.from_pretrained(args.model_id)
    # Attach the LoRA adapter weights on top of the base SigLIP model.
    model = PeftModel.from_pretrained(base_model, str(adapter_path))
    processor = SiglipProcessor.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    client = chromadb.PersistentClient(path=str(args.chroma_path))
    collection = client.get_or_create_collection(
        name=args.collection,
        # Cosine distance pairs with the L2-normalized embeddings below.
        metadata={"hnsw:space": "cosine"},
    )
    total_images = len(image_paths)
    progress = tqdm(total=total_images, desc="Indexing images", unit="img")
    indexed_count = 0
    with torch.no_grad():
        for batch_paths, batch_ids in zip(
            batch_iter(image_paths, args.batch_size),
            batch_iter(ids, args.batch_size),
        ):
            images, valid_paths, valid_ids = load_images(batch_paths, batch_ids)
            if not images:
                # Whole batch was unreadable: still advance the progress bar.
                progress.update(len(batch_paths))
                continue
            inputs = processor(images=images, return_tensors="pt")
            inputs = {key: value.to(device) for key, value in inputs.items()}
            embeds = model.get_image_features(**inputs)
            # L2-normalize so cosine similarity reduces to a dot product.
            embeds = F.normalize(embeds, dim=-1)
            embeddings = embeds.detach().cpu().tolist()
            metadatas = []
            for path in valid_paths:
                # Key must line up with load_labels' POSIX-relative paths
                # for the label lookup below to hit.
                rel_path = path.relative_to(args.data_dir).as_posix()
                metadata = {"filepath": rel_path}
                label_data = label_map.get(rel_path)
                if label_data:
                    metadata.update(label_data)
                metadatas.append(metadata)
            # Upsert so re-running the indexer refreshes existing entries
            # instead of failing on duplicate ids.
            collection.upsert(
                ids=valid_ids,
                embeddings=embeddings,
                metadatas=metadatas,
            )
            indexed_count += len(valid_paths)
            progress.update(len(batch_paths))
    progress.close()
    print(
        f"Indexed {indexed_count} images (scanned {total_images}) into collection "
        f"'{args.collection}'."
    )
if __name__ == "__main__":
main()