| """Runtime model for the Style Tagger Space (nearest-neighbor over training set). | |
| - Embeds query image with CLIP (use_fast=True for consistency). | |
| - Projects with your trained StyleProjector. | |
| - Looks up nearest neighbors in a retrieval bundle (FAISS if available). | |
| - Tallies style tags from neighbors and returns {tag: score in [0,1]}. | |
| Expected files next to app.py: | |
| style_projector_and_topic_table.safetensors | |
| style_topic_meta.json | |
| style_vocab.json | |
| retrieval_bundle/ | |
| ├─ vectors.faiss (optional but recommended) | |
| ├─ tag_ids_concat.npy | |
| ├─ tag_offsets.npy | |
| └─ image_ids.npy (optional, for debugging) | |
| Notes: | |
| - If FAISS index is missing but you also ship vectors.npy (optional), we fall back to NumPy top‑K. | |
| - If neither is present, we return a stable dummy output so UI stays responsive. | |
| """ | |
from __future__ import annotations

import os
import json
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from safetensors.torch import load_file as load_safetensors
from transformers import CLIPModel, CLIPProcessor

# Try FAISS (optional at runtime)
try:
    import faiss  # type: ignore
except Exception:
    faiss = None
# -------------------
# Globals
# -------------------
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_PROJECTOR: Optional[nn.Module] = None
_CLIP_MODEL: Optional[CLIPModel] = None
_CLIP_PROCESSOR: Optional[CLIPProcessor] = None
_VOCAB: Optional[List[str]] = None

# Retrieval bundle
_INDEX: Optional[object] = None           # FAISS index, if available
_VEC_NP: Optional[np.ndarray] = None      # fallback: raw vectors (optional file vectors.npy)
_TAGS_CONCAT: Optional[np.ndarray] = None
_TAGS_OFFSETS: Optional[np.ndarray] = None
_IMAGE_IDS: Optional[np.ndarray] = None   # filenames aligned to index rows

_READY = False

# How many neighbors to tally by default
K_DEFAULT = int(os.getenv("STYLE_K", "30"))
# -------------------
# Model definition (matches training projector)
# -------------------
class StyleProjector(nn.Module):
    def __init__(self, d_in: int, d_out: int, use_layer_norm: bool = True):
        super().__init__()
        blocks = [nn.Linear(d_in, 1024), nn.GELU()]
        if use_layer_norm:
            blocks.append(nn.LayerNorm(1024))
        blocks += [nn.Dropout(0.0), nn.Linear(1024, d_out)]
        self.net = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.net(x)
        return nn.functional.normalize(z, dim=-1)
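# Shape sanity check (a minimal sketch; d_in/d_out below are illustrative — the real
# values come from CLIP's projection_dim and style_topic_meta.json):
#   proj = StyleProjector(d_in=512, d_out=128)
#   z = proj(torch.randn(4, 512))  # -> [4, 128], each row L2-normalized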
# -------------------
# Helpers
# -------------------
def _embed_image(image: Image.Image) -> torch.Tensor:
    """CLIP-embed the (optionally cropped) image and L2-normalize to unit length."""
    # Crop 1/8 off the top and bottom to match training preprocessing,
    # controlled via the STYLE_CROP_TB env var.
    if os.getenv("STYLE_CROP_TB", "1") not in ("0", "false", "False"):
        w, h = image.size
        dy = h // 8
        if dy > 0:
            image = image.crop((0, dy, w, h - dy))
    inputs = _CLIP_PROCESSOR(images=[image.convert("RGB")], return_tensors="pt").to(_DEVICE)
    with torch.no_grad():
        feats = _CLIP_MODEL.get_image_features(**inputs)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats  # [1, d_in]
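# Worked crop example (illustrative numbers): for a 1024x768 image, dy = 768 // 8 = 96,
# so the crop box is (0, 96, 1024, 672) and the embedded region is 1024x576.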
def _load_retrieval_bundle(bundle_dir: str) -> None:
    global _INDEX, _VEC_NP, _TAGS_CONCAT, _TAGS_OFFSETS, _IMAGE_IDS
    # Required ragged tag tables: row i's tag ids live at
    # tag_ids_concat[tag_offsets[i]:tag_offsets[i + 1]].
    _TAGS_CONCAT = np.load(os.path.join(bundle_dir, "tag_ids_concat.npy"))
    _TAGS_OFFSETS = np.load(os.path.join(bundle_dir, "tag_offsets.npy"))
    # Filenames (optional but recommended for diagnostics)
    img_ids_path = os.path.join(bundle_dir, "image_ids.npy")
    _IMAGE_IDS = np.load(img_ids_path) if os.path.exists(img_ids_path) else None
    # Preferred: FAISS index
    index_path = os.path.join(bundle_dir, "vectors.faiss")
    if faiss is not None and os.path.exists(index_path):
        _INDEX = faiss.read_index(index_path)
    else:
        _INDEX = None
    # Optional fallback: raw vectors file, if you ship it
    vec_path = os.path.join(bundle_dir, "vectors.npy")
    if os.path.exists(vec_path):
        _VEC_NP = np.load(vec_path).astype(np.float32)  # expected to be L2-normalized
    else:
        _VEC_NP = None
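# Offline sketch (hypothetical helper, never called at runtime): one way to produce
# the bundle this loader expects, assuming `vecs` is an [N, d_proj] float32 array of
# L2-normalized projections and `tags_per_image` holds N lists of vocab tag ids.
def _build_retrieval_bundle_sketch(bundle_dir: str, vecs: np.ndarray,
                                   tags_per_image: List[List[int]]) -> None:
    os.makedirs(bundle_dir, exist_ok=True)
    # Ragged layout: row i's tags occupy tag_ids_concat[tag_offsets[i]:tag_offsets[i + 1]].
    offsets = np.zeros(len(tags_per_image) + 1, dtype=np.int64)
    offsets[1:] = np.cumsum([len(t) for t in tags_per_image])
    concat = (np.concatenate([np.asarray(t, dtype=np.int32) for t in tags_per_image])
              if offsets[-1] else np.array([], dtype=np.int32))
    np.save(os.path.join(bundle_dir, "tag_ids_concat.npy"), concat)
    np.save(os.path.join(bundle_dir, "tag_offsets.npy"), offsets)
    if faiss is not None:
        index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
        index.add(vecs)
        faiss.write_index(index, os.path.join(bundle_dir, "vectors.faiss"))
    np.save(os.path.join(bundle_dir, "vectors.npy"), vecs)  # NumPy fallback path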
# -------------------
# Public API
# -------------------
def load():
    """Load CLIP, projector weights, vocab, and the retrieval bundle if present.

    This function is tolerant: it will not raise if files are missing; _READY stays False.
    """
    global _PROJECTOR, _CLIP_MODEL, _CLIP_PROCESSOR, _VOCAB, _READY
    try:
        # ---- Metadata ----
        with open("style_topic_meta.json", "r", encoding="utf-8") as f:
            meta = json.load(f)
        proj_dim = int(meta["proj_dim"])
        use_ln = bool(meta.get("use_layer_norm", True))
        # ---- CLIP (use_fast=True for consistency; requires torchvision) ----
        model_id = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-base-patch32")
        _CLIP_MODEL = CLIPModel.from_pretrained(model_id).to(_DEVICE)
        _CLIP_PROCESSOR = CLIPProcessor.from_pretrained(model_id, use_fast=True)
        _CLIP_MODEL.eval()
        # ---- Projector ----
        d_in = int(_CLIP_MODEL.config.projection_dim)
        _PROJECTOR = StyleProjector(d_in, proj_dim, use_layer_norm=use_ln).to(_DEVICE)
        tensors = load_safetensors("style_projector_and_topic_table.safetensors")
        with torch.no_grad():
            _PROJECTOR.net[0].weight.copy_(tensors["projector.net.0.weight"])
            _PROJECTOR.net[0].bias.copy_(tensors["projector.net.0.bias"])
            if use_ln:
                # With LayerNorm the Sequential is [Linear, GELU, LayerNorm, Dropout, Linear]
                _PROJECTOR.net[2].weight.copy_(tensors["projector.net.ln.weight"])
                _PROJECTOR.net[2].bias.copy_(tensors["projector.net.ln.bias"])
                last = _PROJECTOR.net[4]
            else:
                # Without LayerNorm it is [Linear, GELU, Dropout, Linear]
                last = _PROJECTOR.net[3]
            last.weight.copy_(tensors["projector.net.last.weight"])
            last.bias.copy_(tensors["projector.net.last.bias"])
        _PROJECTOR.eval()
        # ---- Vocab ----
        with open("style_vocab.json", "r", encoding="utf-8") as f:
            _VOCAB = json.load(f)
        # ---- Retrieval bundle (optional but recommended) ----
        bundle_dir = os.getenv("RETRIEVAL_DIR", os.path.join(os.getcwd(), "retrieval_bundle"))
        if os.path.isdir(bundle_dir):
            _load_retrieval_bundle(bundle_dir)
        _READY = True
    except FileNotFoundError as e:
        print("model.load(): missing file (skeleton mode)", e)
        _READY = False
def predict(image: Image.Image, k: Optional[int] = None) -> Tuple[Dict[str, float], List[Dict[str, object]], Dict[str, int]]:
    """Return (scores_norm, neighbors_detailed, counts_raw).

    - scores_norm: {style_tag: score in [0, 1]} (counts normalized by the max count)
    - neighbors_detailed: [{"filename": str, "similarity": float, "distance": float, "styles": [str, ...]}, ...]
    - counts_raw: {style_tag: int} (exact tallies used to rank)
    """
    global _READY, _INDEX, _VEC_NP, _TAGS_CONCAT, _TAGS_OFFSETS, _VOCAB, _IMAGE_IDS
    if not _READY or _VOCAB is None:
        fallback = _VOCAB[:10] if _VOCAB else ["watercolor", "oil painting", "pixel art", "sketch", "digital painting"]
        return ({tag: 0.0 for tag in fallback}, [], {tag: 0 for tag in fallback})
    # 1) Embed + project the query
    with torch.no_grad():
        q_clip = _embed_image(image)                  # [1, d_in]
        q_proj = _PROJECTOR(q_clip).cpu().numpy()[0]  # [d_proj], L2-normalized
    # 2) Nearest neighbors (and similarities)
    K = k or K_DEFAULT
    if _INDEX is not None and faiss is not None:
        q = q_proj[np.newaxis, :].astype(np.float32)
        D, I = _INDEX.search(q, K)  # D: similarities, I: indices
        nbrs = I[0]
        sims = D[0]
    elif _VEC_NP is not None:
        sims_all = _VEC_NP @ q_proj  # cosine via dot product, shape [N]
        if K < sims_all.shape[0]:
            idx = np.argpartition(-sims_all, K)[:K]
            order = np.argsort(-sims_all[idx])
            nbrs = idx[order]
            sims = sims_all[nbrs]
        else:
            order = np.argsort(-sims_all)
            nbrs = order[:K]
            sims = sims_all[nbrs]
    else:
        # No retrieval table present
        fallback = {tag: 0.0 for tag in _VOCAB[:10]}
        return (fallback, [], {tag: 0 for tag in fallback})
    # Build the neighbor detail list: filenames, similarities, distances, neighbor styles
    neighbors: List[Dict[str, object]] = []
    for i, s in zip(nbrs, sims):
        # Neighbor styles from the ragged arrays
        start, end = int(_TAGS_OFFSETS[i]), int(_TAGS_OFFSETS[i + 1])
        tag_ids_i = _TAGS_CONCAT[start:end]
        styles_i = [_VOCAB[int(tid)] for tid in tag_ids_i] if len(tag_ids_i) > 0 else []
        # Filename (if available)
        fname = str(_IMAGE_IDS[i]) if _IMAGE_IDS is not None else str(int(i))
        neighbors.append({
            "filename": fname,
            "similarity": float(s),
            "distance": float(1.0 - s),  # cosine distance, since vectors are L2-normalized
            "styles": styles_i,
        })
    # 3) Tally neighbor style tags → counts_raw
    if len(nbrs) == 0:
        fallback = {tag: 0.0 for tag in _VOCAB[:10]}
        return (fallback, neighbors, {tag: 0 for tag in fallback})
    # Concatenate tags across neighbors and bincount (nbrs is non-empty here,
    # so concatenate always receives at least one slice)
    tag_ids_all = np.concatenate([
        _TAGS_CONCAT[int(_TAGS_OFFSETS[i]):int(_TAGS_OFFSETS[i + 1])] for i in nbrs
    ])
    counts = np.bincount(tag_ids_all, minlength=len(_VOCAB)) if tag_ids_all.size else np.zeros(len(_VOCAB), dtype=np.int64)
    counts_raw: Dict[str, int] = {_VOCAB[i]: int(counts[i]) for i in np.nonzero(counts)[0]}
    # 4) Normalize counts → scores_norm in [0, 1]
    maxc = float(counts.max()) if counts.size else 0.0
    if maxc <= 0:
        return ({tag: 0.0 for tag in _VOCAB[:10]}, neighbors, counts_raw)
    scores = counts / maxc
    scores_norm: Dict[str, float] = {_VOCAB[i]: float(scores[i]) for i in np.nonzero(counts)[0]}
    return (scores_norm, neighbors, counts_raw)
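# Minimal local smoke test (a sketch; it runs even in skeleton mode, where predict()
# returns the stable dummy output described in the module docstring).
if __name__ == "__main__":
    load()
    img = Image.new("RGB", (512, 512), color=(128, 128, 128))
    scores, nbrs_detail, counts = predict(img, k=10)
    top = sorted(scores.items(), key=lambda kv: -kv[1])[:5]
    print("top tags:", top)
    print("neighbors returned:", len(nbrs_detail), "| total tag tallies:", sum(counts.values()))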