# Hugging Face Space upload metadata (scrape residue, preserved as a comment):
# uploaded by malusama — "Set tokenizer special tokens and use numpy image input"
# commit 08dd3b5 (verified)
from functools import lru_cache
import importlib
import json
import os
import sys
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModel
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
MODEL_ID = "malusama/M2-Encoder-0.4B"
MODEL_REVISION = "5b673bc65a31d72c9245ad7a161ba5a378f6ad88"
DEVICE = torch.device("cpu")
@lru_cache(maxsize=1)
def load_components():
    """Download the pinned model snapshot and build the inference components.

    Returns:
        tuple: ``(model, tokenizer, image_processor)`` — the M2-Encoder model
        moved to ``DEVICE`` and set to eval mode, the custom GLM Chinese
        tokenizer, and the custom image processor.

    Cached with ``lru_cache(maxsize=1)`` so the expensive download and
    instantiation happen only once per process.
    """
    model_dir = snapshot_download(
        repo_id=MODEL_ID,
        revision=MODEL_REVISION,
    )
    # The snapshot ships custom Python modules (tokenization_glm,
    # image_processing_m2_encoder); put the snapshot dir on sys.path so
    # importlib can find them.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
    # Fix: use a context manager so the config file handle is closed
    # deterministically instead of being leaked.
    config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        tokenizer_config = json.load(config_file)
    tokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = importlib.import_module(
        "image_processing_m2_encoder"
    ).M2EncoderImageProcessor.from_pretrained(model_dir)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer, image_processor
def parse_labels(text: str):
    """Split free-form user text into a flat list of candidate labels.

    Labels may be separated by newlines and/or commas; surrounding
    whitespace is stripped and blank entries are dropped.
    """
    return [
        token.strip()
        for line in text.splitlines()
        for token in line.split(",")
        if token.strip()
    ]
def run_demo(image: Image.Image, candidate_text: str):
    """Score one image against candidate labels with M2-Encoder.

    Args:
        image: The uploaded image (PIL image or a numpy array from Gradio).
        candidate_text: Labels, newline- and/or comma-separated.

    Returns:
        tuple[str, str]: A one-line summary of the best match and a JSON
        string with the full ranking.

    Raises:
        ValueError: If no image was uploaded or no labels were given.
    """
    candidates = parse_labels(candidate_text)
    if image is None:
        raise ValueError("Please upload an image.")
    if not candidates:
        raise ValueError("Please enter at least one label.")
    # Gradio's numpy input arrives as an ndarray; normalize to PIL.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    model, tokenizer, image_processor = load_components()
    with torch.no_grad():
        encoded_text = tokenizer(
            candidates,
            padding="max_length",
            truncation=True,
            max_length=52,
            return_special_tokens_mask=True,
            return_tensors="pt",
        )
        encoded_image = image_processor(image.convert("RGB"), return_tensors="pt")
        # Move tensor-like entries to the target device; pass others through.
        encoded_text = {
            name: value.to(DEVICE) if hasattr(value, "to") else value
            for name, value in encoded_text.items()
        }
        encoded_image = {
            name: value.to(DEVICE) if hasattr(value, "to") else value
            for name, value in encoded_image.items()
        }
        text_features = model(**encoded_text)
        image_features = model(**encoded_image)
        similarities = (
            image_features.image_embeds @ text_features.text_embeds.t()
        ).squeeze(0)
        confidences = similarities.softmax(dim=-1)

    # Rank labels by probability, highest first.
    ranked = sorted(
        (
            (label, float(sim), float(conf))
            for label, sim, conf in zip(
                candidates, similarities.tolist(), confidences.tolist()
            )
        ),
        key=lambda entry: entry[2],
        reverse=True,
    )
    best_label, _, best_prob = ranked[0]
    summary = f"Top match: {best_label} ({best_prob:.4f})"
    payload = {
        "ranked_results": [
            {"label": label, "score": sim, "prob": conf}
            for label, sim, conf in ranked
        ]
    }
    return summary, json.dumps(payload, ensure_ascii=False, indent=2)
def build_demo():
    """Assemble and return the Gradio Blocks UI.

    ``gradio`` is imported lazily inside this function so the rest of the
    module stays importable when it is not installed.
    """
    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown(
            """
# M2-Encoder 0.4B
Upload one image and enter candidate labels, one per line or comma-separated.
This Space runs on `CPU Basic`, so the first request can be slow.
"""
        )
        with gr.Row():
            image_box = gr.Image(type="numpy", label="Image")
            labels_box = gr.Textbox(
                label="Candidate Labels",
                lines=8,
                value="杰尼龟\n妙蛙种子\n小火龙\n皮卡丘",
            )
        submit = gr.Button("Run Matching", variant="primary")
        summary_box = gr.Textbox(label="Summary")
        details_box = gr.Textbox(label="Results JSON", lines=18)
        submit.click(
            run_demo,
            inputs=[image_box, labels_box],
            outputs=[summary_box, details_box],
            api_name=False,
            show_api=False,
        )
    return demo
# Build the UI eagerly at import time; when gradio is absent the demo stays
# None and the error is deferred until an actual launch is attempted.
demo = None
try:
    demo = build_demo()
except ModuleNotFoundError:
    pass

if __name__ == "__main__":
    if demo is None:
        raise RuntimeError("gradio is required to launch this app.")
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)