Spaces:
Sleeping
Sleeping
File size: 7,402 Bytes
3a167c5 b01d241 3a167c5 d390d1b 3a167c5 d390d1b 0636e54 d390d1b 0636e54 3a167c5 619dbf0 3a167c5 b01d241 3a167c5 b01d241 3a167c5 6e114a2 3a167c5 d390d1b 0636e54 d390d1b 0636e54 fe820c4 0636e54 d390d1b 0636e54 d390d1b 0636e54 fe820c4 0636e54 3a167c5 68f7921 3a167c5 68f7921 0636e54 68f7921 3a167c5 619dbf0 3a167c5 619dbf0 3a167c5 68f7921 3a167c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
#!/usr/bin/env python3
from __future__ import annotations
from contextlib import asynccontextmanager
import os
from pathlib import Path
from typing import Any, Dict, List
import chromadb
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from peft import PeftModel
from pydantic import BaseModel, Field
from transformers import SiglipModel, SiglipProcessor
from keyword_filters import (
CATEGORY_SYNONYMS,
COLOR_SYNONYMS,
VIBE_SYNONYMS,
extract_keywords,
)
# Absolute directory holding the item images, located next to this file.
# It is served under /static/images when it exists.
DATA_DIR = Path(__file__).resolve().parent.joinpath("data/2026-01-11").resolve()
class SearchRequest(BaseModel):
    """Request body accepted by the ``/search`` endpoint."""

    # Required free-text query; must contain at least one character.
    query: str = Field(default=..., min_length=1)
    # Number of results to return: defaults to 10, constrained to 1..100.
    k: int = Field(default=10, le=100, ge=1)
def resolve_adapter_path(adapter_path: Path) -> Path:
    """Locate the directory that actually contains ``adapter_config.json``.

    The given directory is checked first, then its ``best_model``
    subdirectory. When neither holds the config file, ``adapter_path`` is
    returned unchanged so the caller decides how to handle the missing file.
    """
    config_name = "adapter_config.json"
    for location in (adapter_path, adapter_path / "best_model"):
        if (location / config_name).exists():
            return location
    return adapter_path
def extract_query_filters(query: str) -> Dict[str, List[str]]:
    """Extract the category, color and vibe keywords mentioned in ``query``.

    The query is wrapped in a one-element list and passed to
    ``extract_keywords`` once per synonym table. The result maps
    ``"categories"``, ``"colors"`` and ``"vibes"`` to whatever that helper
    returns for the matching table.
    """
    texts = [query]
    facet_tables = (
        ("categories", CATEGORY_SYNONYMS),
        ("colors", COLOR_SYNONYMS),
        ("vibes", VIBE_SYNONYMS),
    )
    return {
        facet: extract_keywords(texts, synonyms)
        for facet, synonyms in facet_tables
    }
def build_where_filter(
    categories: List[str], colors: List[str], vibes: List[str]
) -> Dict[str, Any] | None:
    """Build a Chroma ``where`` filter from the selected facet values.

    Args:
        categories: Accepted category names; an item matches when its
            ``category`` metadata is any one of them.
        colors: Color names; every ``color_<name>`` metadata flag must be
            ``True``.
        vibes: Vibe names; every ``vibe_<name>`` metadata flag must be
            ``True``.

    Returns:
        ``None`` when all three lists are empty, the bare clause when exactly
        one condition results, and otherwise a single flat ``{"$and": [...]}``
        over all conditions.
    """
    clauses: List[Dict[str, Any]] = []
    if categories:
        clauses.append({"category": {"$in": categories}})
    # Keep every condition in ONE flat list. Chroma rejects "$and"/"$or" with
    # fewer than two expressions, so wrapping a single color or vibe in its
    # own "$and" (as a nested group) would make the whole filter invalid.
    clauses.extend({f"color_{color}": True} for color in colors)
    clauses.extend({f"vibe_{vibe}": True} for vibe in vibes)

    if not clauses:
        return None
    if len(clauses) == 1:
        return clauses[0]
    return {"$and": clauses}
def build_filter_candidates(filters: Dict[str, List[str]]) -> List[Dict[str, Any]]:
    """List ``where`` filters to try, from most to least restrictive.

    ``filters`` is read through its ``"categories"``, ``"colors"`` and
    ``"vibes"`` keys; missing or empty entries count as "facet not present".
    A facet combination is emitted only when every facet it uses has at
    least one value. The order is: all three facets, each pair, then each
    facet on its own.
    """
    category_values = filters.get("categories") or []
    color_values = filters.get("colors") or []
    vibe_values = filters.get("vibes") or []

    # Each row is (use_category, use_color, use_vibe), strictest first.
    selections = [
        (True, True, True),
        (True, True, False),
        (True, False, True),
        (False, True, True),
        (True, False, False),
        (False, True, False),
        (False, False, True),
    ]

    ordered_filters: List[Dict[str, Any]] = []
    for use_category, use_color, use_vibe in selections:
        # Skip combinations that rely on a facet the query did not mention.
        if use_category and not category_values:
            continue
        if use_color and not color_values:
            continue
        if use_vibe and not vibe_values:
            continue
        candidate = build_where_filter(
            category_values if use_category else [],
            color_values if use_color else [],
            vibe_values if use_vibe else [],
        )
        if candidate:
            ordered_filters.append(candidate)
    return ordered_filters
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and vector collection at startup and stash them on ``app.state``.

    Before yielding, this sets ``app.state.device``, ``app.state.model``,
    ``app.state.processor`` and ``app.state.collection``. Nothing is torn
    down after the ``yield``.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    siglip_checkpoint = "google/siglip-base-patch16-256-multilingual"
    lora_dir = resolve_adapter_path(Path("outputs/ko-clip-lora"))

    print("Loading SigLIP + LoRA model...")
    # Wrap the base SigLIP weights with the LoRA adapter found on disk.
    siglip = SiglipModel.from_pretrained(siglip_checkpoint)
    lora_model = PeftModel.from_pretrained(siglip, str(lora_dir))
    text_processor = SiglipProcessor.from_pretrained(siglip_checkpoint)
    lora_model.to(target_device)
    lora_model.eval()

    # The collection is created on first run and uses cosine distance.
    client = chromadb.PersistentClient(path="chroma_db")
    items = client.get_or_create_collection(
        name="maple_items",
        metadata={"hnsw:space": "cosine"},
    )

    shared_state = (
        ("device", target_device),
        ("model", lora_model),
        ("processor", text_processor),
        ("collection", items),
    )
    for attribute, value in shared_state:
        setattr(app.state, attribute, value)
    yield
# Application instance; heavy resources are loaded by the lifespan handler.
app = FastAPI(lifespan=lifespan)

# CORS origins come from the comma-separated ALLOWED_ORIGINS variable when it
# is set and non-empty; otherwise the two local dev-server origins are used.
allowed_origins_env = os.getenv("ALLOWED_ORIGINS")
if not allowed_origins_env:
    allowed_origins = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ]
else:
    # Strip each entry and drop the ones that end up blank.
    allowed_origins = list(
        filter(None, map(str.strip, allowed_origins_env.split(",")))
    )

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve item images only when the data directory is actually present.
if not DATA_DIR.exists():
    print(f"Warning: static images directory not found: {DATA_DIR}")
else:
    app.mount("/static/images", StaticFiles(directory=str(DATA_DIR)), name="images")
@app.get("/")
def health() -> Dict[str, str]:
    """Health-check endpoint: always answers with a fixed ``ok`` status."""
    return dict(status="ok")
def _has_hits(results: Dict[str, Any] | None) -> bool:
    """Return True when a query result holds at least one id for the first query."""
    if not results:
        return False
    id_batches = results.get("ids")
    return bool(id_batches) and bool(id_batches[0])


def _embed_query(query: str) -> List[float]:
    """Encode ``query`` with the loaded text model and return its L2-normalized vector."""
    model: SiglipModel = app.state.model
    processor: SiglipProcessor = app.state.processor
    device: torch.device = app.state.device
    with torch.inference_mode():
        text_inputs = processor(text=[query], return_tensors="pt", padding=True)
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}
        text_embeds = model.get_text_features(**text_inputs)
        text_embeds = F.normalize(text_embeds, dim=-1)
    return text_embeds[0].detach().cpu().tolist()


def _query_with_fallback(
    collection: Any,
    query_embedding: List[float],
    k: int,
    where_candidates: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Query the collection, relaxing the metadata filter until something matches.

    Candidates are tried in the given order. A candidate that raises is
    reported on stdout and skipped (deliberate best effort); a candidate that
    returns no ids is skipped as well. When no candidate yields hits, an
    unfiltered query is issued and returned.
    """
    for where_filter in where_candidates:
        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=k,
                where=where_filter,
                include=["distances", "metadatas"],
            )
        except Exception as exc:  # noqa: BLE001
            print(f"Filtered query failed ({exc}); trying less strict.")
            continue
        if _has_hits(results):
            return results
    return collection.query(
        query_embeddings=[query_embedding],
        n_results=k,
        include=["distances", "metadatas"],
    )


def _format_result(
    item_id: str, distance: float | None, metadata: Dict[str, Any] | None
) -> Dict[str, Any]:
    """Turn one (id, distance, metadata) hit into the JSON item returned to clients."""
    filepath = ""
    item_name = ""
    label_ko = ""
    if metadata:
        filepath = metadata.get("filepath", "")
        item_name = metadata.get("item_name", "") or ""
        label_ko = metadata.get("label_ko") or metadata.get("label") or ""
    # Fall back to the file name (without extension) when no name is stored.
    if not item_name and filepath:
        item_name = Path(filepath).stem
    image_url = f"/static/images/{filepath}" if filepath else ""
    # Similarity is reported as 1 - distance, clamped so it never goes negative.
    similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
    return {
        "id": item_id,
        "filepath": filepath,
        "distance": distance,
        "similarity": similarity,
        "image_url": image_url,
        "item_name": item_name,
        "label_ko": label_ko,
        "label": label_ko,
    }


@app.post("/search")
def search(payload: SearchRequest) -> Dict[str, Any]:
    """Return the ``payload.k`` nearest items for a free-text query.

    The stripped query is embedded, keyword filters are derived from it, and
    the collection is queried with progressively less strict filters before
    falling back to an unfiltered search.

    Args:
        payload: Validated request carrying ``query`` and ``k``.

    Returns:
        A dict with the stripped ``query``, the requested ``k`` and a
        ``results`` list of formatted items.

    Raises:
        HTTPException: 400 when the query is empty after stripping whitespace.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    query_embedding = _embed_query(query)
    where_candidates = build_filter_candidates(extract_query_filters(query))
    results = _query_with_fallback(
        app.state.collection, query_embedding, payload.k, where_candidates
    )

    # `or [[]]` also covers a key that is present but set to None.
    ids: List[str] = (results.get("ids") or [[]])[0]
    distances: List[float] = (results.get("distances") or [[]])[0]
    metadatas: List[Dict[str, Any]] = (results.get("metadatas") or [[]])[0]

    response_items = [
        _format_result(item_id, distance, metadata)
        for item_id, distance, metadata in zip(ids, distances, metadatas)
    ]
    return {
        "query": query,
        "k": payload.k,
        "results": response_items,
    }
|