import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from itertools import islice
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


# ---------- utils ----------
def flux_to_gray(flux_array):
    """Convert a raw flux array to an 8-bit grayscale PIL image via a 1-99 percentile stretch."""
    a = np.array(flux_array, dtype=np.float32)
    a = np.squeeze(a)
    if a.ndim == 3:
        # Collapse the (assumed) band/channel axis, taken to be the smallest dimension.
        axis = int(np.argmin(a.shape))
        a = np.nanmean(a, axis=axis)
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    lo = np.nanpercentile(a, 1)
    hi = np.nanpercentile(a, 99)
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        lo, hi = float(np.nanmin(a)), float(np.nanmax(a))
    norm = np.clip((a - lo) / (hi - lo + 1e-9), 0, 1)
    arr = (norm * 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L")


# ---------- model ----------
model_id = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
model.eval()

# ---------- in-memory index ----------
INDEX = {
    "feats": None,   # torch.Tensor [N, 512]
    "ids": [],       # list[str]
    "thumbs": [],    # list[PIL.Image]
    "bands": [],     # list[str]
}


def build_index(n=200):
    ds = load_dataset("MultimodalUniverse/jwst", split="train", streaming=True)
    feats, ids, thumbs, bands = [], [], [], []
    for rec in islice(ds, int(n)):
        pil = flux_to_gray(rec["image"]["flux"]).convert("RGB")
        t = pil.copy()
        t.thumbnail((128, 128))
        with torch.no_grad():
            inp = processor(images=pil, return_tensors="pt")
            f = model.get_image_features(**inp)  # [1, 512]
            f = F.normalize(f, p=2, dim=-1)[0]   # [512]
        feats.append(f)
        ids.append(str(rec.get("object_id")))
        bands.append(str(rec["image"].get("band")))
        thumbs.append(t)
    if not feats:
        return "No records indexed."
    INDEX["feats"] = torch.stack(feats)  # [N, 512]
    INDEX["ids"] = ids
    INDEX["thumbs"] = thumbs
    INDEX["bands"] = bands
    return f"Index built: {len(ids)} images."


def search(text_query, image_query, k=5):
    if INDEX["feats"] is None:
        return [], "Build the index first."
    with torch.no_grad():
        if text_query and str(text_query).strip():
            inputs = processor(text=[str(text_query).strip()], return_tensors="pt")
            q = model.get_text_features(**inputs)   # [1, 512]
        elif image_query is not None:
            pil = image_query.convert("RGB")
            inputs = processor(images=pil, return_tensors="pt")
            q = model.get_image_features(**inputs)  # [1, 512]
        else:
            return [], "Enter text or upload an image."
        q = F.normalize(q, p=2, dim=-1)[0]          # [512]
    sims = (INDEX["feats"] @ q).cpu()               # [N]
    k = min(int(k), sims.shape[0])
    topk = torch.topk(sims, k=k)
    items = []
    for idx in topk.indices.tolist():
        cap = f"id: {INDEX['ids'][idx]}  score: {float(sims[idx]):.3f}  band: {INDEX['bands'][idx]}"
        items.append((INDEX["thumbs"][idx], cap))
    return items, f"Returned {k} results."


# ---------- UI ----------
with gr.Blocks() as demo:
    gr.Markdown("JWST multimodal search: build the index, then query by text or image")
    n = gr.Slider(50, 1000, value=200, step=10, label="How many images to index")
    build_btn = gr.Button("Build index")
    status = gr.Textbox(label="Status", lines=2)
    build_btn.click(build_index, inputs=n, outputs=status)

    # Wire up the search function defined above so the UI is actually multimodal.
    text_q = gr.Textbox(label="Text query")
    img_q = gr.Image(label="Image query", type="pil")
    k = gr.Slider(1, 20, value=5, step=1, label="Top-k")
    search_btn = gr.Button("Search")
    gallery = gr.Gallery(label="Results")
    search_btn.click(search, inputs=[text_q, img_q, k], outputs=[gallery, status])

demo.launch()
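
# Usage sketch (illustrative; the query string below is just an example): with the
# UI running, set the slider, click "Build index", then type a query or upload a
# cutout and click "Search". The same pipeline can be exercised programmatically:
#
#   build_index(n=20)                                # index the first 20 streamed records
#   items, msg = search("spiral galaxy", None, k=3)  # top-3 matches for a text query
#
# Note that demo.launch() blocks, so run these lines before launching the app or
# in a separate REPL session.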