# Space / app.py
# Author: Jaywalker061707 — commit b8afcdb ("Update app.py"), 3.61 kB
import gradio as gr
from datasets import load_dataset
from itertools import islice
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
import torch.nn.functional as F
# ---------- utils ----------
def flux_to_gray(flux_array, lo_pct=1.0, hi_pct=99.0):
    """Render a flux array as an 8-bit grayscale PIL image with a percentile stretch.

    Singleton dimensions are squeezed away. A remaining 3-D cube is averaged
    (ignoring NaNs) along its shortest axis. Intensities are clipped to the
    [lo_pct, hi_pct] percentiles of the *finite* pixels and mapped to 0..255.

    Args:
        flux_array: Array-like of flux values (converted to float32).
        lo_pct: Lower percentile of the stretch (default 1).
        hi_pct: Upper percentile of the stretch (default 99).

    Returns:
        A PIL image in mode "L". Non-finite pixels (NaN/inf) are rendered black.

    Raises:
        ValueError: If the squeezed array has more than three dimensions.
    """
    a = np.squeeze(np.asarray(flux_array, dtype=np.float32))
    if a.ndim == 3:
        # Collapse the shortest axis; presumably a band/channel axis -- confirm
        # against the dataset schema.
        a = np.nanmean(a, axis=int(np.argmin(a.shape)))
    elif a.ndim > 3:
        raise ValueError(f"expected a 2-D image or 3-D cube, got shape {a.shape}")
    a = np.atleast_2d(a)  # 0-D / 1-D inputs become a single-row image

    finite = np.isfinite(a)
    if not finite.any():
        # Nothing to stretch (e.g. an all-NaN cutout): return a black image.
        return Image.fromarray(np.zeros(a.shape, dtype=np.uint8))

    # Compute the stretch on finite pixels only. Replacing NaNs with 0 first
    # (as before) let masked borders drag the lower percentile toward 0.
    valid = a[finite]
    lo, hi = (float(v) for v in np.percentile(valid, [lo_pct, hi_pct]))
    if hi <= lo:
        # Degenerate percentiles (near-constant image): fall back to full range.
        lo, hi = float(valid.min()), float(valid.max())
    a = np.where(finite, a, lo)  # masked pixels map to black
    norm = np.clip((a - lo) / (hi - lo + 1e-9), 0.0, 1.0)
    # No explicit mode: a 2-D uint8 array is inferred as "L", and the `mode`
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray((norm * 255).astype(np.uint8))
# ---------- model ----------
# CLIP checkpoint shared by the image encoder (indexing) and the text/image
# query encoders (search).
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
# eval() returns the module itself, so the mode switch can be chained.
model = CLIPModel.from_pretrained(model_id).eval()
# ---------- in-memory index ----------
# Process-wide search index filled by build_index(). Row i of "feats"
# corresponds to entry i of every list.
INDEX = dict(
    feats=None,  # torch.Tensor [N, 512] of L2-normalised CLIP image embeddings; None until built
    ids=[],      # list[str]: object_id of each indexed record
    thumbs=[],   # list[PIL.Image]: thumbnails (at most 128x128) shown in results
    bands=[],    # list[str]: band label of each indexed image
)
def build_index(n=200, batch_size=32):
    """Embed the first *n* JWST records with CLIP and store them in ``INDEX``.

    Records are streamed from the ``train`` split of ``MultimodalUniverse/jwst``.
    Each ``rec["image"]["flux"]`` is rendered with ``flux_to_gray``, embedded
    with CLIP in batches, and L2-normalised so that dot products are cosine
    similarities.

    Args:
        n: Number of records to index (coerced with ``int``; the UI slider may
            deliver a float).
        batch_size: Images per CLIP forward pass; values below 1 act as 1.

    Returns:
        A status message for the UI. If no records were read, the existing
        index is left untouched.
    """
    ds = load_dataset("MultimodalUniverse/jwst", split="train", streaming=True)
    batch_size = max(1, int(batch_size))
    chunks, ids, thumbs, bands = [], [], [], []
    pending = []  # full-size RGB images not yet embedded, in record order

    def _embed_pending():
        # One CLIP forward pass over `pending`; keep the unit-norm feature rows.
        if not pending:
            return
        with torch.no_grad():
            inp = processor(images=pending, return_tensors="pt")
            emb = model.get_image_features(**inp)  # [B, 512]
        chunks.append(F.normalize(emb, p=2, dim=-1))
        pending.clear()

    for rec in islice(ds, int(n)):
        pil = flux_to_gray(rec["image"]["flux"]).convert("RGB")
        thumb = pil.copy()
        thumb.thumbnail((128, 128))  # in place, keeps aspect ratio
        pending.append(pil)
        thumbs.append(thumb)
        ids.append(str(rec.get("object_id")))
        bands.append(str(rec["image"].get("band")))
        if len(pending) >= batch_size:
            _embed_pending()
    _embed_pending()

    if not chunks:
        return "No records indexed."
    # Swap all fields in a single update so a concurrent search() is less
    # likely to pair new features with old ids/thumbnails.
    INDEX.update(feats=torch.cat(chunks), ids=ids, thumbs=thumbs, bands=bands)
    return f"Index built: {len(ids)} images."
def search(text_query, image_query, k=5):
    """Return the *k* indexed images most similar to a text or image query.

    The text query takes precedence; the image query is used only when the text
    is empty or whitespace. Similarity is the dot product of L2-normalised CLIP
    embeddings (cosine similarity).

    Args:
        text_query: Free-text query, or None/empty.
        image_query: PIL image, or a numpy array (converted with
            ``Image.fromarray``), or None.
        k: Number of results requested; clamped to [0, number of indexed images].

    Returns:
        ``(items, message)``. ``items`` is a list of ``(thumbnail, caption)``
        tuples, best match first. ``message`` is a status string for the UI.
    """
    if INDEX["feats"] is None:
        return [], "Build the index first."
    text = str(text_query).strip() if text_query else ""
    with torch.no_grad():
        if text:
            # CLIP's text encoder has a fixed context length; truncate long
            # queries instead of letting the model raise.
            inputs = processor(text=[text], return_tensors="pt", truncation=True)
            q = model.get_text_features(**inputs)  # [1, 512]
        elif image_query is not None:
            if isinstance(image_query, np.ndarray):
                # Gradio's default gr.Image type delivers numpy arrays.
                image_query = Image.fromarray(image_query)
            inputs = processor(images=image_query.convert("RGB"), return_tensors="pt")
            q = model.get_image_features(**inputs)  # [1, 512]
        else:
            return [], "Enter text or upload an image."
        q = F.normalize(q, p=2, dim=-1)[0]  # [512]
        sims = (INDEX["feats"] @ q).cpu()  # [N]

    # Negative k would make torch.topk raise; k larger than N is capped.
    k = max(0, min(int(k), sims.shape[0]))
    top = torch.topk(sims, k=k)
    items = []
    for score, idx in zip(top.values.tolist(), top.indices.tolist()):
        cap = f"id: {INDEX['ids'][idx]} score: {score:.3f} band: {INDEX['bands'][idx]}"
        items.append((INDEX["thumbs"][idx], cap))
    return items, f"Returned {k} results."
# ---------- UI ----------
with gr.Blocks() as demo:
    # The title previously contained mojibake ("β€”") for an em dash.
    gr.Markdown("JWST multimodal search — build the index")

    # Step 1: stream records and embed them into INDEX via build_index().
    n = gr.Slider(50, 1000, value=200, step=10, label="How many images to index")
    build_btn = gr.Button("Build index")
    status = gr.Textbox(label="Status", lines=2)
    build_btn.click(build_index, inputs=n, outputs=status)

    # Step 2: query the index. search() uses the text query when it is
    # non-empty and falls back to the uploaded image otherwise.
    query_text = gr.Textbox(label="Text query")
    query_image = gr.Image(label="Image query", type="pil")
    k_slider = gr.Slider(1, 20, value=5, step=1, label="Number of results")
    search_btn = gr.Button("Search")
    gallery = gr.Gallery(label="Results")
    search_status = gr.Textbox(label="Search status")
    search_btn.click(
        search,
        inputs=[query_text, query_image, k_slider],
        outputs=[gallery, search_status],
    )

demo.launch()