toratal3 committed on
Commit
6843d1b
·
0 Parent(s):

Fix image encoder for Gradio 6.x: handle filepath/dict input types

Browse files
Files changed (2) hide show
  1. app.py +73 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SigLIP 2 Text & Image Encoder -- HuggingFace Space
3
+ Encodes text or image queries to 768-dim vectors for the Epstein photo search.
4
+
5
+ Model: google/siglip2-base-patch16-224
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor
13
+
14
MODEL_NAME = "google/siglip2-base-patch16-224"

# One-time startup load: weights, tokenizer, and image processor.
# .eval() switches off training-only behavior (dropout etc.).
print(f"Loading {MODEL_NAME}...")
model = AutoModel.from_pretrained(MODEL_NAME).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
print(f"Model loaded. Text hidden size: {model.config.text_config.hidden_size}")
21
+
22
def encode(text: str) -> list:
    """Encode one text query into a unit-norm 768-dim embedding.

    Returns the vector as a plain Python list so it serializes as JSON.
    """
    # SigLIP expects fixed-length, 64-token padded inputs.
    batch = tokenizer(
        [text],
        return_tensors="pt",
        padding="max_length",
        max_length=64,
        truncation=True,
    )
    with torch.no_grad():
        pooled = model.text_model(**batch).pooler_output
    unit = F.normalize(pooled, dim=-1)
    return unit[0].tolist()
28
+
29
def _coerce_to_pil(image) -> Image.Image:
    """Normalize any payload Gradio may deliver into an RGB PIL image.

    Gradio 5+/6+ can pass a PIL Image, a filepath string, or a FileData
    dict ({"path": ..., "url": ...}) whose url may be a base64 data URI.

    Raises:
        gr.Error: if the payload shape is not recognized.
    """
    if isinstance(image, Image.Image):
        # Fix: previously a direct PIL input skipped RGB normalization,
        # so RGBA/palette/grayscale uploads reached the processor as-is.
        return image.convert("RGB")
    if isinstance(image, str):
        return Image.open(image).convert("RGB")
    if isinstance(image, dict):
        path = image.get("path") or image.get("url")
        if path and path.startswith("data:"):
            # Inline data URI: "data:<mime>;base64,<payload>"
            import base64
            import io
            _header, payload = path.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
        if path:
            return Image.open(path).convert("RGB")
        raise gr.Error(f"Cannot parse image dict: {list(image.keys())}")
    raise gr.Error(f"Unexpected image type: {type(image)}")


def encode_image(image) -> list:
    """Encode one image into a unit-norm 768-dim embedding.

    Args:
        image: PIL Image, filepath string, or Gradio FileData dict.

    Returns:
        The normalized vector as a plain Python list (JSON-serializable).

    Raises:
        gr.Error: if no image is provided or the payload is unparseable.
    """
    print(f"encode_image called with type: {type(image)}")
    if image is None:
        raise gr.Error("No image provided")
    image = _coerce_to_pil(image)
    print(f"Image size: {image.size}, mode: {image.mode}")
    inputs = processor(images=[image], return_tensors="pt")
    with torch.no_grad():
        feats = model.get_image_features(pixel_values=inputs["pixel_values"])
        # Some transformers versions return a ModelOutput instead of a Tensor.
        if not isinstance(feats, torch.Tensor):
            feats = feats.pooler_output
    feats = F.normalize(feats, dim=-1)
    return feats[0].tolist()
57
+
58
# --- Gradio UI: one tab per modality, each wired to its encoder. ---
with gr.Blocks(title="SigLIP 2 Encoder") as demo:
    gr.Markdown("# SigLIP 2 Encoder\nEncodes text or images to 768-dim normalized vectors using google/siglip2-base-patch16-224")

    with gr.Tab("Text"):
        txt_in = gr.Textbox(label="Text")
        txt_out = gr.JSON(label="Embedding (768-dim)")
        gr.Button("Encode Text").click(
            fn=encode, inputs=txt_in, outputs=txt_out, api_name="encode"
        )

    with gr.Tab("Image"):
        img_in = gr.Image(type="pil", label="Image")
        img_out = gr.JSON(label="Embedding (768-dim)")
        gr.Button("Encode Image").click(
            fn=encode_image, inputs=img_in, outputs=img_out, api_name="encode_image"
        )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.49.0
3
+ sentencepiece
4
+ protobuf
5
+ Pillow