from functools import lru_cache
import importlib
import json
import os
import sys
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModel
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
MODEL_ID = "malusama/M2-Encoder-0.4B"
MODEL_REVISION = "5b673bc65a31d72c9245ad7a161ba5a378f6ad88"
DEVICE = torch.device("cpu")
@lru_cache(maxsize=1)
def load_components():
    """Download the M2-Encoder snapshot and build its runtime components.

    Returns:
        tuple: ``(model, tokenizer, image_processor)`` — the model moved to
        ``DEVICE`` and set to eval mode, plus the custom GLM tokenizer and
        image processor shipped inside the snapshot.

    Cached with ``lru_cache(maxsize=1)`` so the expensive snapshot download
    and model construction happen at most once per process.
    """
    model_dir = snapshot_download(
        repo_id=MODEL_ID,
        revision=MODEL_REVISION,
    )
    # The snapshot bundles custom Python modules (tokenization_glm,
    # image_processing_m2_encoder); put the snapshot dir on sys.path so
    # importlib can find them below.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
    # Fix: the original used a bare open() inside json.load(), leaking the
    # file handle; a context manager guarantees it is closed.
    config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        tokenizer_config = json.load(config_file)
    # Instantiate the repo's custom tokenizer directly from its module,
    # forwarding the special tokens declared in tokenizer_config.json.
    tokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = importlib.import_module(
        "image_processing_m2_encoder"
    ).M2EncoderImageProcessor.from_pretrained(model_dir)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer, image_processor
def parse_labels(text: str):
    """Split *text* into candidate labels.

    Labels may be separated by newlines and/or commas; surrounding
    whitespace is stripped and empty entries are dropped. Order of
    appearance is preserved.
    """
    return [
        token
        for line in text.splitlines()
        for token in (piece.strip() for piece in line.split(","))
        if token
    ]
def run_demo(image: Image.Image, candidate_text: str):
    """Score one image against candidate labels.

    Args:
        image: PIL image or numpy array (Gradio passes numpy).
        candidate_text: labels, newline- and/or comma-separated.

    Returns:
        tuple[str, str]: a one-line summary of the top match and a JSON
        string with every label's raw score and softmax probability,
        sorted by probability descending.

    Raises:
        ValueError: when no image is supplied or no labels parse.
    """
    labels = parse_labels(candidate_text)
    if image is None:
        raise ValueError("Please upload an image.")
    if not labels:
        raise ValueError("Please enter at least one label.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    model, tokenizer, image_processor = load_components()

    def _to_device(batch):
        # Move tensor-like values onto DEVICE; pass everything else through.
        return {
            key: value.to(DEVICE) if hasattr(value, "to") else value
            for key, value in batch.items()
        }

    with torch.no_grad():
        text_batch = _to_device(
            tokenizer(
                labels,
                padding="max_length",
                truncation=True,
                max_length=52,
                return_special_tokens_mask=True,
                return_tensors="pt",
            )
        )
        image_batch = _to_device(
            image_processor(image.convert("RGB"), return_tensors="pt")
        )
        text_out = model(**text_batch)
        image_out = model(**image_batch)
        # Cosine-style similarity: one image embedding against N text embeddings.
        scores = (image_out.image_embeds @ text_out.text_embeds.t()).squeeze(0)
        probs = scores.softmax(dim=-1)

    ranked = sorted(
        (
            (label, float(score), float(prob))
            for label, score, prob in zip(labels, scores.tolist(), probs.tolist())
        ),
        key=lambda entry: entry[2],
        reverse=True,
    )
    best_label, _, best_prob = ranked[0]
    summary = f"Top match: {best_label} ({best_prob:.4f})"
    payload = {
        "ranked_results": [
            {"label": label, "score": score, "prob": prob}
            for label, score, prob in ranked
        ]
    }
    return summary, json.dumps(payload, ensure_ascii=False, indent=2)
def build_demo():
    """Assemble and return the Gradio Blocks UI.

    ``gradio`` is imported lazily so the rest of the module can be
    imported in environments where it is not installed.
    """
    import gradio as gr

    with gr.Blocks() as app:
        gr.Markdown(
            """
# M2-Encoder 0.4B
Upload one image and enter candidate labels, one per line or comma-separated.
This Space runs on `CPU Basic`, so the first request can be slow.
"""
        )
        with gr.Row():
            image_input = gr.Image(type="numpy", label="Image")
            labels_input = gr.Textbox(
                label="Candidate Labels",
                lines=8,
                value="杰尼龟\n妙蛙种子\n小火龙\n皮卡丘",
            )
        run_button = gr.Button("Run Matching", variant="primary")
        summary_output = gr.Textbox(label="Summary")
        details_output = gr.Textbox(label="Results JSON", lines=18)
        # Hide this handler from the auto-generated API surface.
        run_button.click(
            run_demo,
            inputs=[image_input, labels_input],
            outputs=[summary_output, details_output],
            api_name=False,
            show_api=False,
        )
    return app
# Build the UI at import time so hosting runners that look for a module-level
# `demo` (e.g. HF Spaces) find it; tolerate a missing gradio install by
# leaving `demo` as None instead of crashing the import.
try:
    demo = build_demo()
except ModuleNotFoundError:
    demo = None
if __name__ == "__main__":
    if demo is None:
        raise RuntimeError("gradio is required to launch this app.")
    # Bind to all interfaces on port 7860 (the conventional Spaces port);
    # keep the API docs hidden.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)