maple-data / main.py
gyubin02's picture
keyword filter
d390d1b
#!/usr/bin/env python3
from __future__ import annotations
from contextlib import asynccontextmanager
import os
from pathlib import Path
from typing import Any, Dict, List
import chromadb
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import SiglipModel, SiglipProcessor
from keyword_filters import (
CATEGORY_SYNONYMS,
COLOR_SYNONYMS,
VIBE_SYNONYMS,
extract_keywords,
)
# Absolute path of the dated image snapshot, located next to this file
# (not relative to the working directory).
DATA_DIR = Path(__file__).resolve().parent.joinpath("data", "2026-01-11").resolve()
class SearchRequest(BaseModel):
    """Request body accepted by ``POST /search``."""

    # Free-text query. min_length=1 rejects "", but whitespace-only strings pass
    # validation and are rejected by the endpoint after stripping.
    query: str = Field(..., min_length=1)
    # Number of nearest neighbours to return, constrained to 1..100 (default 10).
    k: int = Field(default=10, ge=1, le=100)
def resolve_adapter_path(adapter_path: Path) -> Path:
    """Return the directory that holds the PEFT adapter's ``adapter_config.json``.

    *adapter_path* itself is checked first, then its ``best_model``
    subdirectory; the first one containing ``adapter_config.json`` wins.
    If neither does, *adapter_path* is returned unchanged.
    """
    for location in (adapter_path, adapter_path / "best_model"):
        if (location / "adapter_config.json").exists():
            return location
    return adapter_path
def extract_query_filters(query: str) -> Dict[str, List[str]]:
    """Extract category / color / vibe keywords from a free-text query.

    Returns a dict with the keys ``"categories"``, ``"colors"`` and
    ``"vibes"``. Each value is whatever ``extract_keywords`` yields for the
    query against the matching synonym table.
    """
    facet_tables = (
        ("categories", CATEGORY_SYNONYMS),
        ("colors", COLOR_SYNONYMS),
        ("vibes", VIBE_SYNONYMS),
    )
    # extract_keywords takes a list of texts; the query is the only one.
    query_texts = [query]
    return {
        facet: extract_keywords(query_texts, synonyms)
        for facet, synonyms in facet_tables
    }
def build_where_filter(
    categories: List[str], colors: List[str], vibes: List[str]
) -> Dict[str, Any] | None:
    """Build a Chroma ``where`` filter requiring every given facet keyword.

    Args:
        categories: Allowed values for the ``category`` metadata field (any-of).
        colors: Colors that must all be flagged via ``color_<name>: True``.
        vibes: Vibes that must all be flagged via ``vibe_<name>: True``.

    Returns:
        ``None`` when no keywords are given, the bare clause when there is
        exactly one, otherwise ``{"$and": [...]}`` over all clauses.
    """
    clauses: List[Dict[str, Any]] = []
    if categories:
        # dict.fromkeys dedupes while preserving order.
        clauses.append({"category": {"$in": list(dict.fromkeys(categories))}})
    # Flatten colors/vibes into the top-level list instead of nesting one
    # "$and" per facet. Chroma rejects "$and" with fewer than two operands, so
    # a nested single-color/single-vibe "$and" made the whole filter invalid.
    clauses.extend({f"color_{color}": True} for color in dict.fromkeys(colors or ()))
    clauses.extend({f"vibe_{vibe}": True} for vibe in dict.fromkeys(vibes or ()))
    if not clauses:
        return None
    if len(clauses) == 1:
        return clauses[0]
    return {"$and": clauses}
def build_filter_candidates(filters: Dict[str, List[str]]) -> List[Dict[str, Any]]:
    """Return Chroma ``where`` filters ordered from most to least restrictive.

    Facet combinations are tried in this order: all three facets; then the
    pairs category+color, category+vibe and color+vibe; then each facet alone.
    A combination is skipped if any of its facets has no keywords. Combinations
    for which ``build_where_filter`` returns a falsy value are dropped.
    """
    facet_keys = {"category": "categories", "color": "colors", "vibe": "vibes"}
    available = {facet: filters.get(key) or [] for facet, key in facet_keys.items()}
    priority = (
        ("category", "color", "vibe"),
        ("category", "color"),
        ("category", "vibe"),
        ("color", "vibe"),
        ("category",),
        ("color",),
        ("vibe",),
    )
    ordered_filters: List[Dict[str, Any]] = []
    for combo in priority:
        if any(not available[facet] for facet in combo):
            continue
        # Facets outside this combination are passed as empty lists.
        picked = {
            facet: (values if facet in combo else [])
            for facet, values in available.items()
        }
        where_clause = build_where_filter(
            picked["category"], picked["color"], picked["vibe"]
        )
        if where_clause:
            ordered_filters.append(where_clause)
    return ordered_filters
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, processor and vector store once when the app starts.

    Loaded objects are stored on ``app.state`` (``device``, ``model``,
    ``processor``, ``collection``) for the request handlers. Nothing is torn
    down after the ``yield``.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    siglip_id = "google/siglip-base-patch16-256-multilingual"
    # Relative path: resolved against the process working directory.
    lora_dir = resolve_adapter_path(Path("outputs/ko-clip-lora"))

    print("Loading SigLIP + LoRA model...")
    siglip = SiglipModel.from_pretrained(siglip_id)
    lora_model = PeftModel.from_pretrained(siglip, str(lora_dir))
    siglip_processor = SiglipProcessor.from_pretrained(siglip_id)
    lora_model.to(run_device)
    lora_model.eval()

    # The collection uses cosine space, so query distances are cosine distances.
    store = chromadb.PersistentClient(path="chroma_db")
    item_collection = store.get_or_create_collection(
        name="maple_items",
        metadata={"hnsw:space": "cosine"},
    )

    state = app.state
    state.device = run_device
    state.model = lora_model
    state.processor = siglip_processor
    state.collection = item_collection
    yield
app = FastAPI(lifespan=lifespan)

# CORS origins come from ALLOWED_ORIGINS, a comma-separated list in which blank
# entries are ignored. When the variable is unset or empty, only the local
# port-5173 origins are allowed (presumably a frontend dev server).
allowed_origins_env = os.getenv("ALLOWED_ORIGINS")
if not allowed_origins_env:
    allowed_origins = ["http://localhost:5173", "http://127.0.0.1:5173"]
else:
    allowed_origins = list(
        filter(None, (entry.strip() for entry in allowed_origins_env.split(",")))
    )

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
)
# Serve the image snapshot under /static/images; search results build their
# image_url values from this prefix. If the snapshot is missing, only warn.
if not DATA_DIR.exists():
    print(f"Warning: static images directory not found: {DATA_DIR}")
else:
    app.mount("/static/images", StaticFiles(directory=str(DATA_DIR)), name="images")
@app.get("/")
def health() -> Dict[str, str]:
    """Health check: always answers ``{"status": "ok"}``."""
    status_payload: Dict[str, str] = dict(status="ok")
    return status_payload
def _embed_query(
    model: SiglipModel,
    processor: SiglipProcessor,
    device: torch.device,
    query: str,
) -> List[float]:
    """Encode *query* with the text tower; return the L2-normalized vector as a list."""
    # NOTE(review): Hugging Face's SigLIP docs recommend padding="max_length"
    # (how the base model was trained). padding=True is kept because the LoRA
    # training / indexing settings are not visible here. Confirm they match.
    with torch.inference_mode():
        text_inputs = processor(text=[query], return_tensors="pt", padding=True)
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}
        text_embeds = model.get_text_features(**text_inputs)
        text_embeds = F.normalize(text_embeds, dim=-1)
        # Convert inside inference mode so the inference tensor is never used outside it.
        return text_embeds[0].detach().cpu().tolist()


def _has_hits(results: Dict[str, Any] | None) -> bool:
    """True when a Chroma query result contains at least one id for the first query."""
    return bool(results and results.get("ids") and results["ids"][0])


def _query_with_fallback(
    collection: Any,
    query_embedding: List[float],
    k: int,
    where_candidates: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Query with each filter in order; fall back to an unfiltered query.

    The first filtered query that returns any hit wins. Filters that raise are
    logged and skipped. If none produced hits, the collection is queried
    without a ``where`` filter; errors from that final query propagate.
    """
    for where_filter in where_candidates:
        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=k,
                where=where_filter,
                include=["distances", "metadatas"],
            )
        except Exception as exc:  # noqa: BLE001 - best effort; try a looser filter
            print(f"Filtered query failed ({exc}); trying less strict.")
            continue
        if _has_hits(results):
            return results
    return collection.query(
        query_embeddings=[query_embedding],
        n_results=k,
        include=["distances", "metadatas"],
    )


def _first_row(results: Dict[str, Any], key: str) -> List[Any]:
    """Return the first query's row for *key*, or [] when absent/None/empty."""
    return (results.get(key) or [[]])[0]


def _format_hit(
    item_id: str, distance: float | None, metadata: Dict[str, Any] | None
) -> Dict[str, Any]:
    """Turn one Chroma hit into the response item shape."""
    metadata = metadata or {}
    # `or ""` so a stored None becomes "" (previously None leaked into "filepath").
    filepath = metadata.get("filepath") or ""
    item_name = metadata.get("item_name") or ""
    label_ko = metadata.get("label_ko") or metadata.get("label") or ""
    if not item_name and filepath:
        item_name = Path(filepath).stem
    image_url = f"/static/images/{filepath}" if filepath else ""
    # Cosine distance -> similarity, clamped at 0 for opposing vectors.
    similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
    return {
        "id": item_id,
        "filepath": filepath,
        "distance": distance,
        "similarity": similarity,
        "image_url": image_url,
        "item_name": item_name,
        "label_ko": label_ko,
        "label": label_ko,
    }


@app.post("/search")
def search(payload: SearchRequest) -> Dict[str, Any]:
    """Text-to-image search over the item collection.

    Embeds the query and extracts category/color/vibe keywords from it. Chroma
    is queried with the strictest keyword filter that yields hits, falling back
    to an unfiltered query.

    Raises:
        HTTPException: 400 when the query is empty after stripping whitespace.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    query_embedding = _embed_query(
        app.state.model, app.state.processor, app.state.device, query
    )
    where_candidates = build_filter_candidates(extract_query_filters(query))
    results = _query_with_fallback(
        app.state.collection, query_embedding, payload.k, where_candidates
    )

    ids: List[str] = _first_row(results, "ids")
    distances: List[float] = _first_row(results, "distances")
    metadatas: List[Dict[str, Any]] = _first_row(results, "metadatas")
    response_items = [
        _format_hit(item_id, distance, metadata)
        for item_id, distance, metadata in zip(ids, distances, metadatas)
    ]
    return {
        "query": query,
        "k": payload.k,
        "results": response_items,
    }