tcm03 commited on
Commit ·
af896ec
1
Parent(s): e9e7244
Add custom inference script for text and sketch
Browse files- inference.py +64 -0
inference.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import base64
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append("code")
|
| 9 |
+
from clip.model import CLIP
|
| 10 |
+
|
| 11 |
+
# Load Model and Utilities
|
| 12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
model = CLIP.from_pretrained("tcm03/tsbir").to(device)
|
| 14 |
+
model.eval()
|
| 15 |
+
|
| 16 |
+
# Preprocessing Functions
|
| 17 |
+
from clip.clip import _transform, tokenize
|
| 18 |
+
transformer = _transform(model.visual.input_resolution, is_train=False)
|
| 19 |
+
|
| 20 |
+
def preprocess_image(image_base64):
|
| 21 |
+
"""Convert base64 encoded image to tensor."""
|
| 22 |
+
image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
|
| 23 |
+
image = transformer(image).unsqueeze(0).to(device)
|
| 24 |
+
return image
|
| 25 |
+
|
| 26 |
+
def preprocess_text(text):
|
| 27 |
+
"""Tokenize text query."""
|
| 28 |
+
return tokenize([str(text)])[0].unsqueeze(0).to(device)
|
| 29 |
+
|
| 30 |
+
def get_fused_embedding(image_base64, text):
|
| 31 |
+
"""Fuse sketch and text features into a single embedding."""
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
# Preprocess Inputs
|
| 34 |
+
image_tensor = preprocess_image(image_base64)
|
| 35 |
+
text_tensor = preprocess_text(text)
|
| 36 |
+
|
| 37 |
+
# Extract Features
|
| 38 |
+
sketch_feature = model.encode_sketch(image_tensor)
|
| 39 |
+
text_feature = model.encode_text(text_tensor)
|
| 40 |
+
|
| 41 |
+
# Normalize Features
|
| 42 |
+
sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
|
| 43 |
+
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
| 44 |
+
|
| 45 |
+
# Fuse Features
|
| 46 |
+
fused_embedding = model.feature_fuse(sketch_feature, text_feature)
|
| 47 |
+
return fused_embedding.cpu().numpy().tolist()
|
| 48 |
+
|
| 49 |
+
# Hugging Face Inference API Entry Point
|
| 50 |
+
def infer(inputs):
|
| 51 |
+
"""
|
| 52 |
+
Inference API entry point.
|
| 53 |
+
Inputs:
|
| 54 |
+
- 'image': Base64 encoded sketch image.
|
| 55 |
+
- 'text': Text query.
|
| 56 |
+
"""
|
| 57 |
+
image_base64 = inputs.get("image", "")
|
| 58 |
+
text_query = inputs.get("text", "")
|
| 59 |
+
if not image_base64 or not text_query:
|
| 60 |
+
return {"error": "Both 'image' (base64) and 'text' are required inputs."}
|
| 61 |
+
|
| 62 |
+
# Generate Fused Embedding
|
| 63 |
+
fused_embedding = get_fused_embedding(image_base64, text_query)
|
| 64 |
+
return {"fused_embedding": fused_embedding}
|