import torch

torch.set_grad_enabled(False)

import base64
import io
import json
import os
from pathlib import Path
from typing import Any

import numpy as np
import spaces
import torchvision.transforms.v2 as v2
from fastapi.requests import Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from gradio import Server
from PIL import Image

from inference_tagger_standalone import (
    PATCH_SIZE,
    Tagger,
    _IMAGENET_MEAN,
    _IMAGENET_STD,
    _open_image,
    _snap,
)
# ---------------------------------------------------------------------------
# Model download + init
# ---------------------------------------------------------------------------
os.system(
    "wget -nv https://huggingface.co/lodestones/tagger-experiment/resolve/main/tagger_proto.safetensors"
)
_VOCAB_PATH = "./tagger_vocab_with_categories.json"

model = Tagger(
    checkpoint_path="./tagger_proto.safetensors",
    vocab_path=_VOCAB_PATH,
    max_size=1024,
)

with open(_VOCAB_PATH) as f:
    _tag2category: dict[str, int] = json.load(f).get("tag2category", {})
# ---------------------------------------------------------------------------
# Category metadata (mirrors tagger_ui_server.py)
# ---------------------------------------------------------------------------
_CAT_OFFSET = 1

CATEGORY_META: dict[int, dict] = {
    0: {"name": "unassigned", "color": "#6b7280", "display_order": 9},
    1: {"name": "general", "color": "#4ade80", "display_order": 4},
    2: {"name": "artist", "color": "#f472b6", "display_order": 0},
    3: {"name": "contributor", "color": "#a78bfa", "display_order": 7},
    4: {"name": "copyright", "color": "#fb923c", "display_order": 1},
    5: {"name": "character", "color": "#60a5fa", "display_order": 2},
    6: {"name": "species", "color": "#facc15", "display_order": 3},
    7: {"name": "disambiguation", "color": "#94a3b8", "display_order": 8},
    8: {"name": "meta", "color": "#e2e8f0", "display_order": 6},
    9: {"name": "lore", "color": "#f87171", "display_order": 5},
}
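
# Sorting by display_order yields the UI ordering:
# artist, copyright, character, species, general, lore, meta,
# contributor, disambiguation, unassigned.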
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_image_source(image: Any) -> str:
    """Normalise the image argument from Gradio.

    Gradio passes uploaded files as a dict:
        {"path": "/tmp/gradio/...", "orig_name": "...", "url": "...", ...}
    URL strings and local paths are passed as plain str.
    """
    if isinstance(image, dict):
        return image.get("path") or image.get("url") or image["orig_name"]
    return str(image)


def _hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
    """'#rrggbb' → (r, g, b) each in [0, 1]."""
    h = hex_color.lstrip("#")
    return tuple(int(h[i : i + 2], 16) / 255.0 for i in (0, 2, 4))
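
# e.g. _hex_to_rgb("#4ade80") → (~0.290, ~0.871, ~0.502),
# i.e. (0x4a, 0xde, 0x80) / 255.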
def _preprocess(img: Image.Image, max_size: int) -> torch.Tensor:
    """Resize + ImageNet-normalise → [1, 3, H, W] float32 CPU tensor."""
    w, h = img.size
    scale = min(1.0, max_size / max(w, h))
    new_w = _snap(round(w * scale), PATCH_SIZE)
    new_h = _snap(round(h * scale), PATCH_SIZE)
    return v2.Compose(
        [
            v2.Resize((new_h, new_w), interpolation=v2.InterpolationMode.LANCZOS),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ]
    )(img).unsqueeze(0)
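
# Worked example (assuming PATCH_SIZE == 16, as in ViT-style backbones):
# a 1600×1200 image with max_size=1024 gives scale = 1024/1600 = 0.64, so
# the resize target is 1024×768 — both already multiples of 16, so _snap
# leaves them unchanged and the backbone sees a 64×48 patch grid.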
def _postprocess(logits: torch.Tensor, floor: float) -> str:
    """sigmoid → filter → sort → build category buckets → JSON string."""
    scores = torch.sigmoid(logits)
    indices = (scores >= floor).nonzero(as_tuple=True)[0]
    values = scores[indices]
    order = values.argsort(descending=True)
    indices = indices[order]
    values = values[order]

    by_category: dict[int, list] = {}
    all_tags: list[dict] = []
    for i, v in zip(indices.tolist(), values.tolist()):
        tag = model.idx2tag[i]
        cat = _tag2category.get(tag, -1) + _CAT_OFFSET
        item = {"tag": tag, "score": round(v, 4), "category": cat}
        all_tags.append(item)
        by_category.setdefault(cat, []).append(item)

    categories = []
    for cat_id in sorted(
        by_category.keys(),
        key=lambda cid: CATEGORY_META.get(cid, {}).get("display_order", cid),
    ):
        meta = CATEGORY_META.get(cat_id, {"name": str(cat_id), "color": "#6b7280"})
        categories.append(
            {
                "id": cat_id,
                "name": meta["name"],
                "color": meta["color"],
                "tags": by_category[cat_id],
            }
        )

    return json.dumps(
        {"tags": all_tags, "categories": categories, "count": len(all_tags)}
    )
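
# Illustrative response shape (the tag name and score below are invented
# placeholders, not real model output):
#
#     {
#       "tags": [{"tag": "example_tag", "score": 0.9876, "category": 1}],
#       "categories": [
#         {"id": 1, "name": "general", "color": "#4ade80",
#          "tags": [{"tag": "example_tag", "score": 0.9876, "category": 1}]}
#       ],
#       "count": 1
#     }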
def _pil_to_base64(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()
def _build_custom_pca(
    proj_norm: np.ndarray, h_p: int, w_p: int, color1: str, color2: str, color3: str
) -> Image.Image:
    """Blend the three normalised PC channels using user-chosen colours.

    For each patch: output = PC1_val * color1_rgb
                           + PC2_val * color2_rgb
                           + PC3_val * color3_rgb

    The result is divided by the largest blended value actually produced,
    so the brightest patch reaches full intensity and everything stays in
    [0, 1] before clamping.
    """
    c1 = np.array(_hex_to_rgb(color1), dtype=np.float32)
    c2 = np.array(_hex_to_rgb(color2), dtype=np.float32)
    c3 = np.array(_hex_to_rgb(color3), dtype=np.float32)

    # proj_norm: [N, 3], values in [0, 1]
    blended = (
        proj_norm[:, 0:1] * c1 + proj_norm[:, 1:2] * c2 + proj_norm[:, 2:3] * c3
    )  # [N, 3]

    # normalise so the brightest patch reaches full intensity
    mx = blended.max()
    if mx > 0:
        blended /= mx

    rgb = blended.reshape(h_p, w_p, 3)
    patch_img = Image.fromarray((rgb * 255).clip(0, 255).astype("uint8"), "RGB")
    return patch_img.resize(
        (w_p * PATCH_SIZE, h_p * PATCH_SIZE), resample=Image.NEAREST
    )
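
# With the default pure primaries (color1="#ff0000", color2="#00ff00",
# color3="#0000ff") the three tint vectors form the identity matrix, so
# blended == proj_norm; each channel of proj_norm already peaks at ~1.0,
# making the brightness division effectively a no-op — the "custom" view
# then matches the fixed PC1→R, PC2→G, PC3→B "full" view.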
# ---------------------------------------------------------------------------
# GPU-isolated helpers
# ---------------------------------------------------------------------------
@spaces.GPU  # ZeroGPU: run this call on a GPU worker
def _gpu_extract_descriptor(pixel_values: torch.Tensor) -> np.ndarray:
    """Extract the FEATURE_DIM=6400 image descriptor via forward_embedding.

    Returns a [6400] float32 numpy array on CPU.
    """
    pv = pixel_values.to(model.device)
    with (
        torch.no_grad(),
        torch.autocast(device_type=model.device.type, dtype=model.dtype),
    ):
        features = model.model.forward_embedding(pv)  # [1, 6400]
    return features[0].cpu().numpy()  # [6400]


@spaces.GPU
def _gpu_infer(pixel_values: torch.Tensor) -> torch.Tensor:
    """Move tensor to device, run model forward, return CPU logits."""
    pv = pixel_values.to(model.device)
    with (
        torch.no_grad(),
        torch.autocast(device_type=model.device.type, dtype=model.dtype),
    ):
        logits = model.model(pv)[0]
    return logits.float().cpu()


@spaces.GPU
def _gpu_pca_extract(pixel_values: torch.Tensor) -> tuple:
    """Backbone forward → normalised PCA projection + grid dims.

    Returns (proj_norm_np [N,3], h_p, w_p) on CPU — no colour mapping here
    so callers can apply both full and custom colourings without re-running
    the backbone.
    """
    pv = pixel_values.to(model.device)
    with (
        torch.no_grad(),
        torch.autocast(device_type=model.device.type, dtype=model.dtype),
    ):
        patch_tokens, h_p, w_p = model.model.backbone.get_image_tokens(pv)
    tokens = patch_tokens[0].float()
    tokens_c = tokens - tokens.mean(dim=0, keepdim=True)
    _, _, Vt = torch.linalg.svd(tokens_c, full_matrices=False)
    projected = tokens_c @ Vt[:3].T  # [N, 3]
    lo = projected.min(dim=0).values
    hi = projected.max(dim=0).values
    proj_norm = (projected - lo) / (hi - lo + 1e-8)  # [N, 3] in [0,1]
    return proj_norm.cpu().numpy(), h_p, w_p
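
# The SVD above is a compact exact PCA: the rows of Vt[:3] are the top-3
# principal directions of the centred patch tokens. An equivalent sketch
# using torch's built-in approximate low-rank PCA (not what this file
# uses, just the same idea):
#
#     U, S, V = torch.pca_lowrank(tokens, q=3, center=True)
#     projected = (tokens - tokens.mean(dim=0)) @ V  # [N, 3]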
# ---------------------------------------------------------------------------
# gradio.Server
# ---------------------------------------------------------------------------
app = Server(title="DINOv3 Tagger UI")

templates = Jinja2Templates(directory=Path(__file__).parent / "tagger_ui" / "templates")
templates.env.filters["format_number"] = lambda v: f"{v:,}"
# ---- Gradio API endpoints --------------------------------------------------
def get_tags(image: str, max_size: int = 1024, floor: float = 0.05) -> str:
    """Tag an image. Returns JSON: {tags, categories, count}."""
    src = _resolve_image_source(image)
    img = _open_image(src)
    pv = _preprocess(img, max_size)
    logits = _gpu_infer(pv)
    return _postprocess(logits, floor)
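
# Example client call (a sketch — the Space id and api_name are assumptions
# not shown in this file):
#
#     from gradio_client import Client
#
#     client = Client("user/dinov3-tagger")  # hypothetical Space id
#     result = client.predict("photo.jpg", 1024, 0.05, api_name="/get_tags")
#     print(json.loads(result)["count"])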
def get_pca(
    image: str,
    max_size: int = 1024,
    color1: str = "#ff0000",  # PC1 colour for custom view
    color2: str = "#00ff00",  # PC2 colour
    color3: str = "#0000ff",  # PC3 colour
) -> str:
    """Return JSON: {full: <base64 PNG>, custom: <base64 PNG>}.

    full   — PC1→R, PC2→G, PC3→B (fixed).
    custom — each PC channel tinted by the user-supplied hex colours,
             additively blended and normalised to [0,1].
    """
    src = _resolve_image_source(image)
    img = _open_image(src)
    pv = _preprocess(img, max_size)

    proj_norm, h_p, w_p = _gpu_pca_extract(pv)  # GPU: backbone + PCA

    # full rainbow (CPU, fast)
    rgb_full = proj_norm.reshape(h_p, w_p, 3)
    full_patch = Image.fromarray((rgb_full * 255).clip(0, 255).astype("uint8"), "RGB")
    full_img = full_patch.resize(
        (w_p * PATCH_SIZE, h_p * PATCH_SIZE), resample=Image.NEAREST
    )

    # custom colour blend (CPU, fast)
    custom_img = _build_custom_pca(proj_norm, h_p, w_p, color1, color2, color3)

    return json.dumps(
        {
            "full": _pil_to_base64(full_img),
            "custom": _pil_to_base64(custom_img),
        }
    )
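
# Decoding the payload back into PIL images on the client side (a sketch;
# `payload` stands for the JSON string returned above):
#
#     data = json.loads(payload)
#     full = Image.open(io.BytesIO(base64.b64decode(data["full"])))
#     custom = Image.open(io.BytesIO(base64.b64decode(data["custom"])))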
def get_similarity(image_a: str, image_b: str, max_size: int = 1024) -> str:
    """Extract FEATURE_DIM=6400 descriptors for two images and return their
    cosine similarity.

    Returns JSON:
        {
            "score": float,          # cosine similarity in [-1, 1]
            "desc_a": [6400 floats], # L2-normalised descriptor for image A
            "desc_b": [6400 floats], # L2-normalised descriptor for image B
        }
    """
    src_a = _resolve_image_source(image_a)
    src_b = _resolve_image_source(image_b)
    img_a = _open_image(src_a)
    img_b = _open_image(src_b)
    pv_a = _preprocess(img_a, max_size)
    pv_b = _preprocess(img_b, max_size)

    # Run both through the backbone in separate GPU calls
    # (spaces.GPU does not support batching across different-sized tensors)
    feat_a = _gpu_extract_descriptor(pv_a)  # [6400]
    feat_b = _gpu_extract_descriptor(pv_b)  # [6400]

    # L2-normalise
    feat_a = feat_a / (np.linalg.norm(feat_a) + 1e-8)
    feat_b = feat_b / (np.linalg.norm(feat_b) + 1e-8)
    score = float(np.dot(feat_a, feat_b))

    return json.dumps(
        {
            "score": round(score, 6),
            "desc_a": feat_a.tolist(),
            "desc_b": feat_b.tolist(),
        }
    )
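
# Because the returned descriptors are already L2-normalised, a client can
# reproduce the score — or compare against other cached descriptors — with
# a plain dot product (`payload` stands for the JSON string returned above):
#
#     data = json.loads(payload)
#     a = np.asarray(data["desc_a"], dtype=np.float32)
#     b = np.asarray(data["desc_b"], dtype=np.float32)
#     assert abs(float(a @ b) - data["score"]) < 1e-4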
# ---- FastAPI routes --------------------------------------------------------
async def index(request: Request):
    return templates.TemplateResponse(
        request,
        "index.html",
        {
            "num_tags": model.num_tags,
            "vocab_path": _VOCAB_PATH,
            "category_meta": CATEGORY_META,
        },
    )


app.launch(ssr_mode=False)