#!/usr/bin/env python3 from __future__ import annotations import argparse import json from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple, TypeVar import chromadb import torch import torch.nn.functional as F from peft import PeftModel from PIL import Image from tqdm import tqdm from transformers import SiglipModel, SiglipProcessor from keyword_filters import ( CATEGORY_SYNONYMS, COLOR_SYNONYMS, VIBE_SYNONYMS, extract_keywords, ) IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"} T = TypeVar("T") def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Index images with SigLIP + LoRA embeddings into ChromaDB." ) parser.add_argument( "--data-dir", type=Path, default=Path("data/2026-01-11"), help="Root directory containing images to index.", ) parser.add_argument( "--model-id", default="google/siglip-base-patch16-256-multilingual", help="Base SigLIP model ID.", ) parser.add_argument( "--adapter-path", type=Path, default=Path("outputs/ko-clip-lora"), help="Path to the LoRA adapter directory.", ) parser.add_argument( "--batch-size", type=int, default=32, help="Batch size for image embedding.", ) parser.add_argument( "--chroma-path", type=Path, default=Path("chroma_db"), help="Directory for ChromaDB persistent storage.", ) parser.add_argument( "--collection", default="maple_items", help="ChromaDB collection name.", ) parser.add_argument( "--labels-path", type=Path, default=None, help="Path to labels.jsonl (defaults to data-dir/labels/labels.jsonl).", ) return parser.parse_args() def resolve_adapter_path(adapter_path: Path) -> Path: if (adapter_path / "adapter_config.json").exists(): return adapter_path candidate = adapter_path / "best_model" if (candidate / "adapter_config.json").exists(): return candidate return adapter_path def find_images(root: Path) -> List[Path]: if not root.exists(): raise FileNotFoundError(f"Data directory not found: {root}") images = [ path for path in root.rglob("*") if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS ] return sorted(images) def batch_iter(items: List[T], batch_size: int) -> Iterable[List[T]]: for start in range(0, len(items), batch_size): yield items[start : start + batch_size] def build_ids(paths: List[Path]) -> List[str]: counts = {} ids: List[str] = [] for path in paths: stem = path.stem count = counts.get(stem, 0) ids.append(stem if count == 0 else f"{stem}_{count}") counts[stem] = count + 1 return ids def load_images( paths: List[Path], ids: List[str] ) -> Tuple[List[Image.Image], List[Path], List[str]]: images: List[Image.Image] = [] valid_paths: List[Path] = [] valid_ids: List[str] = [] for path, item_id in zip(paths, ids): try: with Image.open(path) as image: images.append(image.convert("RGB")) valid_paths.append(path) valid_ids.append(item_id) except Exception as exc: # noqa: BLE001 print(f"Skipping unreadable image: {path} ({exc})") return images, valid_paths, valid_ids def normalize_label(value: Optional[str]) -> Optional[str]: if value is None: return None if isinstance(value, str): trimmed = value.strip() return trimmed or None return str(value) def detect_category(texts: List[str]) -> Optional[str]: lowered_texts = [text.lower() for text in texts if text] for category, keywords in CATEGORY_SYNONYMS.items(): for keyword in keywords: keyword_lower = keyword.lower() if any(keyword_lower in text for text in lowered_texts): return category return None def collect_label_texts( item_name: Optional[str], label_ko: Optional[str], tags: List[str], query_variants: List[str], attributes: Dict[str, object], item_type_guess: Optional[str], ) -> List[str]: texts: List[str] = [] for value in (item_name, label_ko, item_type_guess): if value: texts.append(value) texts.extend(tag for tag in tags if tag) texts.extend(variant for variant in query_variants if variant) for value in attributes.values(): if isinstance(value, list): for entry in value: entry_norm = normalize_label(entry) if entry_norm: texts.append(entry_norm) else: entry_norm = normalize_label(value) if entry_norm: texts.append(entry_norm) return texts def load_labels(labels_path: Path) -> Dict[str, Dict[str, object]]: if not labels_path.exists(): print(f"Labels file not found, continuing without labels: {labels_path}") return {} label_map: Dict[str, Dict[str, object]] = {} with labels_path.open("r", encoding="utf-8") as file: for line_no, line in enumerate(file, start=1): line = line.strip() if not line: continue try: record = json.loads(line) except json.JSONDecodeError as exc: print(f"Skipping label line {line_no}: {exc}") continue image_path = record.get("image_path") if not image_path: continue item_name = normalize_label(record.get("item_name")) label_ko = normalize_label(record.get("label_ko")) tags = record.get("tags_ko") or [] tag_texts = [normalize_label(tag) for tag in tags if tag is not None] tag_texts = [tag for tag in tag_texts if tag] query_variants = record.get("query_variants_ko") or [] variant_texts = [ normalize_label(variant) for variant in query_variants if variant is not None ] variant_texts = [variant for variant in variant_texts if variant] attributes = record.get("attributes") or {} item_type_guess = normalize_label(attributes.get("item_type_guess")) if not item_name and not label_ko and not tag_texts: continue normalized_path = Path(str(image_path)).as_posix().lstrip("./") label_map[normalized_path] = {} if item_name: label_map[normalized_path]["item_name"] = item_name if label_ko: label_map[normalized_path]["label_ko"] = label_ko label_map[normalized_path]["label"] = label_ko texts = collect_label_texts( item_name, label_ko, tag_texts, variant_texts, attributes, item_type_guess, ) category = detect_category(texts) if category: label_map[normalized_path]["category"] = category colors = extract_keywords(texts, COLOR_SYNONYMS) if colors: for color in colors: label_map[normalized_path][f"color_{color}"] = True vibes = extract_keywords(texts, VIBE_SYNONYMS) if vibes: for vibe in vibes: label_map[normalized_path][f"vibe_{vibe}"] = True print(f"Loaded labels for {len(label_map)} images from {labels_path}") return label_map def main() -> None: args = parse_args() image_paths = find_images(args.data_dir) if not image_paths: print(f"No images found under {args.data_dir}") return ids = build_ids(image_paths) adapter_path = resolve_adapter_path(args.adapter_path) labels_path = args.labels_path or args.data_dir / "labels/labels.jsonl" label_map = load_labels(labels_path) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Loading model and processor...") base_model = SiglipModel.from_pretrained(args.model_id) model = PeftModel.from_pretrained(base_model, str(adapter_path)) processor = SiglipProcessor.from_pretrained(args.model_id) model.to(device) model.eval() client = chromadb.PersistentClient(path=str(args.chroma_path)) collection = client.get_or_create_collection( name=args.collection, metadata={"hnsw:space": "cosine"}, ) total_images = len(image_paths) progress = tqdm(total=total_images, desc="Indexing images", unit="img") indexed_count = 0 with torch.no_grad(): for batch_paths, batch_ids in zip( batch_iter(image_paths, args.batch_size), batch_iter(ids, args.batch_size), ): images, valid_paths, valid_ids = load_images(batch_paths, batch_ids) if not images: progress.update(len(batch_paths)) continue inputs = processor(images=images, return_tensors="pt") inputs = {key: value.to(device) for key, value in inputs.items()} embeds = model.get_image_features(**inputs) embeds = F.normalize(embeds, dim=-1) embeddings = embeds.detach().cpu().tolist() metadatas = [] for path in valid_paths: rel_path = path.relative_to(args.data_dir).as_posix() metadata = {"filepath": rel_path} label_data = label_map.get(rel_path) if label_data: metadata.update(label_data) metadatas.append(metadata) collection.upsert( ids=valid_ids, embeddings=embeddings, metadatas=metadatas, ) indexed_count += len(valid_paths) progress.update(len(batch_paths)) progress.close() print( f"Indexed {indexed_count} images (scanned {total_images}) into collection " f"'{args.collection}'." ) if __name__ == "__main__": main()