Jaywalker061707 committed
Commit 0dd2819 · verified · 1 Parent(s): 52ea788

Update app.py

Files changed (1)
app.py +48 -17
app.py CHANGED
@@ -1,10 +1,13 @@
 import gradio as gr
 from datasets import load_dataset
+from itertools import islice
 import numpy as np
 from PIL import Image
 import torch
 from transformers import CLIPModel, CLIPProcessor
+import torch.nn.functional as F
 
+# ---------- utils ----------
 def flux_to_gray(flux_array):
     a = np.array(flux_array, dtype=np.float32)
     a = np.squeeze(a)
@@ -12,32 +15,60 @@ def flux_to_gray(flux_array):
     axis = int(np.argmin(a.shape))
     a = np.nanmean(a, axis=axis)
     a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
-    lo = np.nanpercentile(a, 1); hi = np.nanpercentile(a, 99)
+    lo = np.nanpercentile(a, 1)
+    hi = np.nanpercentile(a, 99)
     if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
         lo, hi = float(np.nanmin(a)), float(np.nanmax(a))
     norm = np.clip((a - lo) / (hi - lo + 1e-9), 0, 1)
     arr = (norm * 255).astype(np.uint8)
     return Image.fromarray(arr, mode="L")
 
+# ---------- model ----------
 model_id = "openai/clip-vit-base-patch32"
 model = CLIPModel.from_pretrained(model_id)
 processor = CLIPProcessor.from_pretrained(model_id)
+model.eval()
 
-def test_clip():
+# ---------- in-memory index ----------
+INDEX = {
+    "feats": None,   # torch.Tensor [N, 512]
+    "ids": [],       # list[str]
+    "thumbs": [],    # list[PIL.Image]
+    "bands": []      # list[str]
+}
+
+def build_index(n=200):
     ds = load_dataset("MultimodalUniverse/jwst", split="train", streaming=True)
-    rec = next(iter(ds))
-    pil = flux_to_gray(rec["image"]["flux"]).convert("RGB")
-    with torch.no_grad():
-        image_inputs = processor(images=pil, return_tensors="pt")
-        image_feats = model.get_image_features(**image_inputs)  # [1, 512]
-    return pil, f"image_features shape: {tuple(image_feats.shape)}", f"object_id: {rec.get('object_id')}"
-
-demo = gr.Interface(
-    fn=test_clip,
-    inputs=None,
-    outputs=[gr.Image(type="pil", label="Preview"),
-             gr.Textbox(label="Shape", lines=1),
-             gr.Textbox(label="Info", lines=1)],
-    title="JWST CLIP embedding check (transformers)"
-)
+    feats, ids, thumbs, bands = [], [], [], []
+    for rec in islice(ds, int(n)):
+        pil = flux_to_gray(rec["image"]["flux"]).convert("RGB")
+        t = pil.copy(); t.thumbnail((128, 128))
+        with torch.no_grad():
+            inp = processor(images=pil, return_tensors="pt")
+            f = model.get_image_features(**inp)  # [1, 512]
+            f = F.normalize(f, p=2, dim=-1)[0]   # [512]
+        feats.append(f)
+        ids.append(str(rec.get("object_id")))
+        bands.append(str(rec["image"].get("band")))
+        thumbs.append(t)
+
+    if not feats:
+        return "No records indexed."
+
+    INDEX["feats"] = torch.stack(feats)  # [N, 512]
+    INDEX["ids"] = ids
+    INDEX["thumbs"] = thumbs
+    INDEX["bands"] = bands
+    return f"Index built: {len(ids)} images."
+
+# ---------- UI ----------
+with gr.Blocks() as demo:
+    gr.Markdown("JWST multimodal search — build the index")
+
+    n = gr.Slider(50, 1000, value=200, step=10, label="How many images to index")
+    build_btn = gr.Button("Build index")
+    status = gr.Textbox(label="Status", lines=2)
+
+    build_btn.click(build_index, inputs=n, outputs=status)
+
 demo.launch()
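
The added F.normalize call is what makes the new index cheap to query: once the image features are unit-norm, cosine similarity against a unit-norm query vector is a plain dot product over INDEX["feats"]. A minimal sketch of the text-query side, assuming the model, processor, and INDEX objects defined in app.py above and a completed build_index() run; search_text itself is hypothetical and not part of this commit:

import torch
import torch.nn.functional as F

def search_text(query, k=5):
    # Hypothetical helper (not in this commit): embed the text query with the
    # same CLIP model used for the images, then rank by dot product.
    with torch.no_grad():
        inp = processor(text=[query], return_tensors="pt", padding=True)
        q = model.get_text_features(**inp)  # [1, 512]
        q = F.normalize(q, p=2, dim=-1)[0]  # [512]
    # Dot product equals cosine similarity here, since both sides are unit-norm.
    scores = INDEX["feats"] @ q  # [N]
    top = torch.topk(scores, k=min(k, scores.numel()))
    return [(INDEX["ids"][int(i)], INDEX["bands"][int(i)], float(s))
            for s, i in zip(top.values, top.indices)]

An image-to-image query would work the same way, with get_image_features in place of get_text_features.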