import importlib
import json
import os
import sys
from functools import lru_cache

# HF_ENDPOINT must be set before huggingface_hub is imported: the library
# reads the variable once, at import time.
os.environ["HF_ENDPOINT"] = "https://huggingface.co"

import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModel

MODEL_ID = "malusama/M2-Encoder-0.4B"
# Pin a specific revision so the Space keeps working if the repo changes.
MODEL_REVISION = "5b673bc65a31d72c9245ad7a161ba5a378f6ad88"
DEVICE = torch.device("cpu")


@lru_cache(maxsize=1)
def load_components():
    """Download the model snapshot once and build the model, tokenizer, and image processor."""
    model_dir = snapshot_download(
        repo_id=MODEL_ID,
        revision=MODEL_REVISION,
    )
    # The snapshot ships custom tokenizer / image-processor modules;
    # put the snapshot directory on sys.path so they can be imported.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)

    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )

    with open(
        os.path.join(model_dir, "tokenizer_config.json"), "r", encoding="utf-8"
    ) as config_file:
        tokenizer_config = json.load(config_file)

    tokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = importlib.import_module(
        "image_processing_m2_encoder"
    ).M2EncoderImageProcessor.from_pretrained(model_dir)

    model.to(DEVICE)
    model.eval()
    return model, tokenizer, image_processor


def parse_labels(text: str):
    """Split user input into labels: one per line, commas also accepted."""
    items = []
    for raw in text.splitlines():
        for part in raw.split(","):
            label = part.strip()
            if label:
                items.append(label)
    return items


def run_demo(image: Image.Image, candidate_text: str):
    labels = parse_labels(candidate_text)
    if image is None:
        raise ValueError("Please upload an image.")
    if not labels:
        raise ValueError("Please enter at least one label.")
    if not isinstance(image, Image.Image):
        # gr.Image(type="numpy") delivers an ndarray; convert it to PIL.
        image = Image.fromarray(image)

    model, tokenizer, image_processor = load_components()

    with torch.no_grad():
        text_inputs = tokenizer(
            labels,
            padding="max_length",
            truncation=True,
            max_length=52,
            return_special_tokens_mask=True,
            return_tensors="pt",
        )
        image_inputs = image_processor(image.convert("RGB"), return_tensors="pt")
        text_inputs = {
            key: value.to(DEVICE) if hasattr(value, "to") else value
            for key, value in text_inputs.items()
        }
        image_inputs = {
            key: value.to(DEVICE) if hasattr(value, "to") else value
            for key, value in image_inputs.items()
        }
        text_outputs = model(**text_inputs)
        image_outputs = model(**image_inputs)
        # One image embedding against N label embeddings -> N similarity scores.
        scores = (image_outputs.image_embeds @ text_outputs.text_embeds.t()).squeeze(0)
        probs = scores.softmax(dim=-1)

    rows = [
        (label, float(score), float(prob))
        for label, score, prob in zip(labels, scores.tolist(), probs.tolist())
    ]
    rows.sort(key=lambda row: row[2], reverse=True)
    top_label = rows[0][0]
    top_prob = rows[0][2]
    summary = f"Top match: {top_label} ({top_prob:.4f})"
    details = {
        "ranked_results": [
            {"label": label, "score": score, "prob": prob}
            for label, score, prob in rows
        ]
    }
    return summary, json.dumps(details, ensure_ascii=False, indent=2)


def build_demo():
    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # M2-Encoder 0.4B

            Upload one image and enter candidate labels, one per line or
            comma-separated. This Space runs on `CPU Basic`, so the first
            request can be slow.
""" ) with gr.Row(): image_input = gr.Image(type="numpy", label="Image") labels_input = gr.Textbox( label="Candidate Labels", lines=8, value="杰尼龟\n妙蛙种子\n小火龙\n皮卡丘", ) run_button = gr.Button("Run Matching", variant="primary") summary_output = gr.Textbox(label="Summary") details_output = gr.Textbox(label="Results JSON", lines=18) run_button.click( run_demo, inputs=[image_input, labels_input], outputs=[summary_output, details_output], api_name=False, show_api=False, ) return demo try: demo = build_demo() except ModuleNotFoundError: demo = None if __name__ == "__main__": if demo is None: raise RuntimeError("gradio is required to launch this app.") demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)