File size: 6,510 Bytes
703b5b1 3d7ad32 0647956 703b5b1 3d7ad32 703b5b1 3d7ad32 703b5b1 0647956 703b5b1 0647956 703b5b1 3d7ad32 0647956 3d7ad32 703b5b1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """
space/_glb.py
-------------
Builds a GLB mesh (sphere per point) from UMAP coords for gr.Model3D.
PointCloud primitives render at 1px in Three.js regardless of scale;
small spheres give controllable apparent size.
"""
from __future__ import annotations

import html
import json
import os
import struct
import tempfile
from pathlib import Path

import numpy as np
def _gradio_tmp() -> str:
    """Return the directory Gradio serves temporary files from, creating it.

    Gradio 5 only serves files inside its own temp dir; a file such as
    /tmp/tmpXXX.glb lies outside it and is answered with 403. The directory is
    ``$GRADIO_TEMP_DIR`` when that variable is set and non-empty, otherwise the
    resolved ``<system temp dir>/gradio`` path, matching what Gradio uses
    internally.
    """
    override = os.environ.get("GRADIO_TEMP_DIR")
    if override:
        target = override
    else:
        target = str((Path(tempfile.gettempdir()) / "gradio").resolve())
    os.makedirs(target, exist_ok=True)
    return target
# Sphere / legend palette (RGB, 0-255). build_glb and build_legend_html pick a
# colour by a model's index in viz["model_names"], cycling with
# ``i % len(_COLORS_RGB)`` once there are more than six models.
# NOTE(review): the per-entry names below assume model_names lists the student
# first and then the teachers in order — colours are chosen by position, not
# by name. Confirm against the code that builds ``viz``.
_COLORS_RGB: list[tuple[int, int, int]] = [
    (230, 237, 243),   # student  — #e6edf3
    (124,  58, 237),   # teacher0 — #7c3aed
    (  6, 182, 212),   # teacher1 — #06b6d4
    (245, 158,  11),   # teacher2 — #f59e0b
    ( 52, 211, 153),   # teacher3 — #34d399
    (244, 114, 182),   # teacher4 — #f472b6
]
# RGB colour of the probe (user-input) spheres drawn by build_glb.
_PROBE_COLOR = (255, 255, 255)
def _hex(r: int, g: int, b: int) -> str:
    """Format three channel ints as a lowercase ``#rrggbb`` CSS colour.

    Channels are expected in 0-255; values outside that range are not clamped.
    """
    return "#" + "".join(format(channel, "02x") for channel in (r, g, b))
def _inject_material(glb: bytes) -> bytes:
    """Return a copy of *glb* with a matte PBR material on every primitive.

    trimesh exports vertex-coloured meshes with POSITION + COLOR_0 but no
    material, so the PBR renderer behind gr.Model3D falls back to the glTF
    default material (metallicFactor=1, roughnessFactor=1). Under a neutral
    environment a fully-metallic surface renders dark, which is why the model
    looked gray. A matte (metallic=0) material lets the per-vertex COLOR_0
    show through, lit correctly.

    Only the JSON chunk is rewritten. The BIN chunk, including its own 8-byte
    chunk header, is copied through unchanged. Existing materials are kept, but
    every primitive's ``material`` reference is overwritten with the new one.

    Args:
        glb: A binary glTF 2.0 container whose first chunk is JSON.

    Returns:
        A new GLB byte string with the same binary payload.

    Raises:
        ValueError: If *glb* is not a well-formed GLB 2.0 container (too short,
            bad magic, unsupported version, truncated, or first chunk not JSON).
    """
    if len(glb) < 20:
        raise ValueError("not a GLB: too short for header and first chunk header")
    magic, version, total_len = struct.unpack_from("<III", glb, 0)
    if magic != 0x46546C67:  # ASCII "glTF", little-endian
        raise ValueError("not a GLB: bad magic")
    if version != 2:
        raise ValueError(f"unsupported GLB version {version}")
    if total_len > len(glb):
        raise ValueError("truncated GLB: header length exceeds data")

    json_len, json_type = struct.unpack_from("<II", glb, 12)
    if json_type != 0x4E4F534A:  # ASCII "JSON"
        raise ValueError("malformed GLB: first chunk is not JSON")
    json_end = 20 + json_len
    if json_end > total_len:
        raise ValueError("malformed GLB: JSON chunk overruns container")

    # The spec pads JSON with spaces. Also tolerate NUL padding from other
    # writers, which json.loads would otherwise reject.
    gltf = json.loads(glb[20:json_end].rstrip(b" \x00"))
    # Stop at the declared container length so stray trailing bytes are not
    # mistaken for part of the BIN chunk.
    bin_chunk = glb[json_end:total_len]

    gltf.setdefault("materials", []).append({
        "pbrMetallicRoughness": {
            "baseColorFactor": [1, 1, 1, 1],
            "metallicFactor": 0.0,
            "roughnessFactor": 0.85,
        },
        "doubleSided": True,
    })
    midx = len(gltf["materials"]) - 1
    for mesh in gltf.get("meshes", []):
        for prim in mesh["primitives"]:
            prim["material"] = midx

    new_json = json.dumps(gltf, separators=(",", ":")).encode("utf-8")
    new_json += b" " * ((-len(new_json)) % 4)  # chunks must be 4-byte aligned

    return b"".join((
        struct.pack("<III", 0x46546C67, 2, 12 + 8 + len(new_json) + len(bin_chunk)),
        struct.pack("<II", len(new_json), 0x4E4F534A),  # JSON chunk header
        new_json,
        bin_chunk,
    ))
def build_glb(
viz: dict,
coords3d: "np.ndarray | None",
probe_points: list[dict],
) -> str | None:
"""Return path to a temporary .glb with one small sphere per embedding point."""
if coords3d is None or len(coords3d) == 0 or not viz.get("model_names"):
return None
import trimesh
model_names = viz["model_names"]
labels = np.array(viz["labels"])
# Adaptive radius: 1.8 % of the data bounding-box diagonal
span = float(np.linalg.norm(coords3d.max(axis=0) - coords3d.min(axis=0)))
radius = max(span * 0.018, 0.04)
# Build template sphere once, scale per group
tpl = trimesh.creation.icosphere(subdivisions=1, radius=1.0)
tpl_v = tpl.vertices.astype(np.float64) # (42, 3)
tpl_f = tpl.faces # (80, 3)
n_v = len(tpl_v)
all_verts : list[np.ndarray] = []
all_faces : list[np.ndarray] = []
all_colors : list[np.ndarray] = []
offset = 0
def _add_group(pts: np.ndarray, rgb: tuple[int, int, int], r: float) -> None:
nonlocal offset
color = np.array([*rgb, 255], dtype=np.uint8)
for pt in pts:
all_verts.append(tpl_v * r + pt)
all_faces.append(tpl_f + offset)
all_colors.append(np.tile(color, (n_v, 1)))
offset += n_v
for i, name in enumerate(model_names):
mask = labels == name
if not mask.any():
continue
pts = coords3d[mask].astype(np.float64)
r = radius * (1.6 if name == "student" else 1.0)
_add_group(pts, _COLORS_RGB[i % len(_COLORS_RGB)], r)
if probe_points:
probe_pts = np.array([[p["x"], p["y"], p["z"]] for p in probe_points],
dtype=np.float64)
_add_group(probe_pts, _PROBE_COLOR, radius * 2.0)
if not all_verts:
return None
vertices = np.concatenate(all_verts, axis=0)
faces = np.concatenate(all_faces, axis=0)
colors = np.concatenate(all_colors, axis=0)
# vertex_colors in the constructor → COLOR_0 attribute on export.
mesh = trimesh.Trimesh(
vertices=vertices, faces=faces, vertex_colors=colors, process=False
)
_ = mesh.vertex_normals # force smooth normals so the PBR renderer can shade
glb_bytes = _inject_material(mesh.export(file_type="glb", include_normals=True))
tmp = tempfile.NamedTemporaryFile(
suffix=".glb", dir=_gradio_tmp(), delete=False
)
tmp.write(glb_bytes)
tmp.close()
return tmp.name
def build_legend_html(viz: dict) -> str:
    """Return an HTML legend of coloured dots matching the GLB sphere colours.

    Args:
        viz: Visualization metadata. Only ``viz["model_names"]`` is read.
            Colours are assigned by position in that list, cycling through
            ``_COLORS_RGB``. The model named ``"student"`` gets a larger dot.

    Returns:
        An HTML fragment with one entry per model plus a final "probe" entry
        coloured from ``_PROBE_COLOR``, or ``""`` when there are no model names.
        Model names are HTML-escaped before being embedded.
    """
    if not viz.get("model_names"):
        return ""

    def _entry(dot_color: str, size: str, label: str) -> str:
        """Render one legend row; *label* must already be HTML-safe."""
        return (
            '<div style="display:flex;align-items:center;gap:6px;">'
            f'<div style="width:{size};height:{size};border-radius:50%;'
            f'background:{dot_color};flex-shrink:0;"></div>'
            f'<span style="font-size:11px;color:#8b949e;font-family:monospace;">{label}</span>'
            '</div>'
        )

    items = []
    for i, name in enumerate(viz["model_names"]):
        is_student = name == "student"
        label = (
            "student — Qwen2.5-0.5B" if is_student
            else f"{html.escape(str(name))} — teacher"
        )
        dot_color = _hex(*_COLORS_RGB[i % len(_COLORS_RGB)])
        items.append(_entry(dot_color, "10px" if is_student else "8px", label))

    # Derive the probe swatch from the same constant build_glb uses.
    items.append(_entry(_hex(*_PROBE_COLOR), "8px", "probe — your input"))

    return (
        '<div style="display:flex;flex-wrap:wrap;gap:10px 18px;padding:8px 2px;">'
        + "".join(items)
        + '</div>'
    )
|