"""Runtime model for the Style Tagger Space (nearest-neighbor over training set). - Embeds query image with CLIP (use_fast=True for consistency). - Projects with your trained StyleProjector. - Looks up nearest neighbors in a retrieval bundle (FAISS if available). - Tallies style tags from neighbors and returns {tag: score in [0,1]}. Expected files next to app.py: style_projector_and_topic_table.safetensors style_topic_meta.json style_vocab.json retrieval_bundle/ ├─ vectors.faiss (optional but recommended) ├─ tag_ids_concat.npy ├─ tag_offsets.npy └─ image_ids.npy (optional, for debugging) Notes: - If FAISS index is missing but you also ship vectors.npy (optional), we fall back to NumPy top‑K. - If neither is present, we return a stable dummy output so UI stays responsive. """ from __future__ import annotations import os, json from typing import Dict, List, Optional, Tuple import numpy as np from PIL import Image import torch import torch.nn as nn from safetensors.torch import load_file as load_safetensors from transformers import CLIPModel, CLIPProcessor # Try FAISS (optional at runtime) try: import faiss # type: ignore except Exception: faiss = None # ------------------- # Globals # ------------------- _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") _PROJECTOR: Optional[nn.Module] = None _CLIP_MODEL: Optional[CLIPModel] = None _CLIP_PROCESSOR: Optional[CLIPProcessor] = None _VOCAB: Optional[List[str]] = None # Retrieval bundle _INDEX: Optional[object] = None # faiss index if available _VEC_NP: Optional[np.ndarray] = None # fallback: raw vectors (optional file vectors.npy) _TAGS_CONCAT: Optional[np.ndarray] = None _TAGS_OFFSETS: Optional[np.ndarray] = None _IMAGE_IDS: Optional[np.ndarray] = None # filenames aligned to index rows _READY = False # How many neighbors to tally by default K_DEFAULT = int(os.getenv("STYLE_K", "30")) # ------------------- # Model definition (matches training projector) # ------------------- class StyleProjector(nn.Module): def __init__(self, d_in: int, d_out: int, use_layer_norm: bool = True): super().__init__() blocks = [nn.Linear(d_in, 1024), nn.GELU()] if use_layer_norm: blocks.append(nn.LayerNorm(1024)) blocks += [nn.Dropout(0.0), nn.Linear(1024, d_out)] self.net = nn.Sequential(*blocks) def forward(self, x: torch.Tensor) -> torch.Tensor: z = self.net(x) return nn.functional.normalize(z, dim=-1) # ------------------- # Helpers # ------------------- def _embed_image(image: Image.Image) -> torch.Tensor: """CLIP-embed the (optionally cropped) image and L2-normalize to unit length.""" # Crop 1/8 top & bottom to match preprocessing, controlled via env var if os.getenv("STYLE_CROP_TB", "1") not in ("0", "false", "False"): w, h = image.size dy = h // 8 if dy > 0: image = image.crop((0, dy, w, h - dy)) inputs = _CLIP_PROCESSOR(images=[image.convert("RGB")], return_tensors="pt").to(_DEVICE) with torch.no_grad(): feats = _CLIP_MODEL.get_image_features(**inputs) feats = feats / feats.norm(dim=-1, keepdim=True) return feats # [1, d_in] def _load_retrieval_bundle(bundle_dir: str) -> None: global _INDEX, _VEC_NP, _TAGS_CONCAT, _TAGS_OFFSETS, _IMAGE_IDS # Required tag tables _TAGS_CONCAT = np.load(os.path.join(bundle_dir, "tag_ids_concat.npy")) _TAGS_OFFSETS = np.load(os.path.join(bundle_dir, "tag_offsets.npy")) # Filenames (optional but recommended for diagnostics) img_ids_path = os.path.join(bundle_dir, "image_ids.npy") _IMAGE_IDS = np.load(img_ids_path) if os.path.exists(img_ids_path) else None # Preferred: FAISS index index_path = os.path.join(bundle_dir, "vectors.faiss") if faiss is not None and os.path.exists(index_path): _INDEX = faiss.read_index(index_path) else: _INDEX = None # Optional fallback: raw vectors file if you ship it vec_path = os.path.join(bundle_dir, "vectors.npy") if os.path.exists(vec_path): _VEC_NP = np.load(vec_path).astype(np.float32) # expected L2-normalized else: _VEC_NP = None # ------------------- # Public API # ------------------- def load(): """Load CLIP, projector weights, vocab, and retrieval bundle if present. This function is tolerant: it will not raise if files are missing; _READY will stay False. """ global _PROJECTOR, _CLIP_MODEL, _CLIP_PROCESSOR, _VOCAB, _READY try: # ---- Metadata ---- with open("style_topic_meta.json", "r", encoding="utf-8") as f: meta = json.load(f) proj_dim = int(meta["proj_dim"]) use_ln = bool(meta.get("use_layer_norm", True)) # ---- CLIP (use_fast=True for consistency; requires torchvision) ---- model_id = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32") _CLIP_MODEL = CLIPModel.from_pretrained(model_id).to(_DEVICE) _CLIP_PROCESSOR = CLIPProcessor.from_pretrained(model_id, use_fast=True) _CLIP_MODEL.eval() # ---- Projector ---- d_in = int(_CLIP_MODEL.config.projection_dim) _PROJECTOR = StyleProjector(d_in, proj_dim, use_layer_norm=use_ln).to(_DEVICE) tensors = load_safetensors("style_projector_and_topic_table.safetensors") with torch.no_grad(): _PROJECTOR.net[0].weight.copy_(tensors["projector.net.0.weight"]) _PROJECTOR.net[0].bias.copy_(tensors["projector.net.0.bias"]) if use_ln: _PROJECTOR.net[2].weight.copy_(tensors["projector.net.ln.weight"]) _PROJECTOR.net[2].bias.copy_(tensors["projector.net.ln.bias"]) last = _PROJECTOR.net[4] else: last = _PROJECTOR.net[3] last.weight.copy_(tensors["projector.net.last.weight"]) last.bias.copy_(tensors["projector.net.last.bias"]) _PROJECTOR.eval() # ---- Vocab ---- with open("style_vocab.json", "r", encoding="utf-8") as f: _VOCAB = json.load(f) # ---- Retrieval bundle (optional but recommended) ---- bundle_dir = os.getenv("RETRIEVAL_DIR", os.path.join(os.getcwd(), "retrieval_bundle")) if os.path.isdir(bundle_dir): _load_retrieval_bundle(bundle_dir) _READY = True except FileNotFoundError as e: print("model.load(): missing file (skeleton mode)", e) _READY = False def predict(image: Image.Image, k: Optional[int] = None) -> Tuple[Dict[str, float], List[Dict[str, object]], Dict[str, int]]: """Return (scores_norm, neighbors_detailed, counts_raw) - scores_norm: {style_tag: score in [0,1]} (counts normalized by max count) - neighbors_detailed: [{"filename": str, "similarity": float, "distance": float, "styles": [str, ...]}, ...] - counts_raw: {style_tag: int} (exact tallies used to rank) """ global _READY, _INDEX, _VEC_NP, _TAGS_CONCAT, _TAGS_OFFSETS, _VOCAB, _IMAGE_IDS if not _READY or _VOCAB is None: fallback = _VOCAB[:10] if _VOCAB else ["watercolor","oil painting","pixel art","sketch","digital painting"] return ({tag: 0.0 for tag in fallback}, [], {tag: 0 for tag in fallback}) # 1) Embed + project query with torch.no_grad(): q_clip = _embed_image(image) # [1, d_in] q_proj = _PROJECTOR(q_clip).cpu().numpy()[0] # [d_proj], L2-normed # 2) Nearest neighbors (and similarities) K = k or K_DEFAULT if _INDEX is not None and faiss is not None: q = q_proj[np.newaxis, :].astype(np.float32) D, I = _INDEX.search(q, K) # D: sims, I: indices nbrs = I[0] sims = D[0] elif _VEC_NP is not None: sims_all = _VEC_NP @ q_proj # cosine via dot, [N] if K < sims_all.shape[0]: idx = np.argpartition(-sims_all, K)[:K] order = np.argsort(-sims_all[idx]) nbrs = idx[order] sims = sims_all[nbrs] else: order = np.argsort(-sims_all) nbrs = order[:K] sims = sims_all[nbrs] else: # No retrieval table present fallback = {tag: 0.0 for tag in _VOCAB[:10]} return (fallback, [], {k: 0 for k in fallback}) # Build neighbor detail list with filenames, sims, distances, and neighbor styles neighbors: List[Dict[str, object]] = [] for i, s in zip(nbrs, sims): # neighbor styles from ragged arrays start, end = int(_TAGS_OFFSETS[i]), int(_TAGS_OFFSETS[i + 1]) tag_ids_i = _TAGS_CONCAT[start:end] styles_i = [_VOCAB[int(tid)] for tid in tag_ids_i] if len(tag_ids_i) > 0 else [] # filename (if available) fname = str(_IMAGE_IDS[i]) if _IMAGE_IDS is not None else str(int(i)) neighbors.append({ "filename": fname, "similarity": float(s), "distance": float(1.0 - s), # cosine distance since vectors are L2-normalized "styles": styles_i, }) # 3) Tally neighbor style tags → counts_raw if len(nbrs) == 0: fallback = {tag: 0.0 for tag in _VOCAB[:10]} return (fallback, neighbors, {k: 0 for k in fallback}) # concatenate tags across neighbors and bincount tag_ids_all = np.concatenate([ _TAGS_CONCAT[int(_TAGS_OFFSETS[i]):int(_TAGS_OFFSETS[i + 1])] for i in nbrs ]) if len(nbrs) else np.array([], dtype=np.int32) counts = np.bincount(tag_ids_all, minlength=len(_VOCAB)) if tag_ids_all.size else np.zeros(len(_VOCAB), dtype=np.int64) counts_raw: Dict[str, int] = { _VOCAB[i]: int(counts[i]) for i in np.nonzero(counts)[0] } # 4) Normalize counts → scores_norm in [0,1] maxc = float(counts.max()) if counts.size else 0.0 if maxc <= 0: return ({tag: 0.0 for tag in _VOCAB[:10]}, neighbors, counts_raw) scores = counts / maxc scores_norm: Dict[str, float] = { _VOCAB[i]: float(scores[i]) for i in np.nonzero(counts)[0] } return (scores_norm, neighbors, counts_raw)