Spaces:
Sleeping
Sleeping
| from functools import lru_cache | |
| import importlib | |
| import json | |
| import os | |
| import sys | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from PIL import Image | |
| from transformers import AutoModel | |
# Force the canonical Hugging Face endpoint in case a mirror is configured
# in the environment (must be set before any hub download is attempted).
os.environ["HF_ENDPOINT"] = "https://huggingface.co"

# Model repository and pinned commit hash so the downloaded remote code and
# weights are reproducible across restarts.
MODEL_ID = "malusama/M2-Encoder-0.4B"
MODEL_REVISION = "5b673bc65a31d72c9245ad7a161ba5a378f6ad88"

# This Space runs on CPU-only hardware (see the note in the UI).
DEVICE = torch.device("cpu")
@lru_cache(maxsize=1)
def load_components():
    """Download the pinned M2-Encoder snapshot and build its components.

    Returns:
        tuple: ``(model, tokenizer, image_processor)`` with the model moved
        to ``DEVICE`` and switched to eval mode.

    The result is cached (``lru_cache``) so the model is instantiated once
    per process instead of on every request — ``lru_cache`` was imported but
    never applied, which made each inference call reload everything.
    """
    model_dir = snapshot_download(
        repo_id=MODEL_ID,
        revision=MODEL_REVISION,
    )
    # The snapshot ships its own tokenizer / image-processor modules
    # (tokenization_glm, image_processing_m2_encoder); make them importable.
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
    # Use a context manager so the config file handle is closed promptly
    # (the original `json.load(open(...))` leaked it).
    config_path = os.path.join(model_dir, "tokenizer_config.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        tokenizer_config = json.load(config_file)
    tokenizer = importlib.import_module("tokenization_glm").GLMChineseTokenizer(
        vocab_file=os.path.join(model_dir, "sp.model"),
        eos_token=tokenizer_config.get("eos_token"),
        pad_token=tokenizer_config.get("pad_token"),
        cls_token=tokenizer_config.get("cls_token"),
        mask_token=tokenizer_config.get("mask_token"),
        unk_token=tokenizer_config.get("unk_token"),
    )
    image_processor = importlib.import_module(
        "image_processing_m2_encoder"
    ).M2EncoderImageProcessor.from_pretrained(model_dir)
    model.to(DEVICE)
    model.eval()
    return model, tokenizer, image_processor
def parse_labels(text: str):
    """Split free-form user text into candidate labels.

    Labels may be separated by newlines, commas, or both; surrounding
    whitespace is stripped and empty entries are discarded. Order of
    appearance is preserved.
    """
    return [
        token
        for line in text.splitlines()
        for token in (piece.strip() for piece in line.split(","))
        if token
    ]
def run_demo(image: Image.Image, candidate_text: str):
    """Match one image against user-supplied candidate labels.

    Args:
        image: Uploaded image (PIL image or numpy array from Gradio).
        candidate_text: Raw label text; parsed by :func:`parse_labels`.

    Returns:
        tuple[str, str]: a one-line summary of the best match, and a JSON
        string with every label ranked by probability.

    Raises:
        ValueError: if no image was uploaded or no labels were entered.
    """
    labels = parse_labels(candidate_text)
    if image is None:
        raise ValueError("Please upload an image.")
    if not labels:
        raise ValueError("Please enter at least one label.")
    # Gradio may hand us a numpy array; normalize to a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    model, tokenizer, image_processor = load_components()

    def _to_device(batch):
        # Move every tensor-like value to DEVICE; leave plain values alone.
        return {
            key: value.to(DEVICE) if hasattr(value, "to") else value
            for key, value in batch.items()
        }

    with torch.no_grad():
        encoded_text = tokenizer(
            labels,
            padding="max_length",
            truncation=True,
            max_length=52,
            return_special_tokens_mask=True,
            return_tensors="pt",
        )
        encoded_image = image_processor(image.convert("RGB"), return_tensors="pt")
        text_out = model(**_to_device(encoded_text))
        image_out = model(**_to_device(encoded_image))
        # Cosine-style similarity of the single image against every label.
        similarity = (image_out.image_embeds @ text_out.text_embeds.t()).squeeze(0)
        probabilities = similarity.softmax(dim=-1)

    ranked = sorted(
        (
            (label, float(score), float(prob))
            for label, score, prob in zip(
                labels, similarity.tolist(), probabilities.tolist()
            )
        ),
        key=lambda entry: entry[2],
        reverse=True,
    )
    best_label, _, best_prob = ranked[0]
    summary = f"Top match: {best_label} ({best_prob:.4f})"
    payload = {
        "ranked_results": [
            {"label": label, "score": score, "prob": prob}
            for label, score, prob in ranked
        ]
    }
    return summary, json.dumps(payload, ensure_ascii=False, indent=2)
def build_demo():
    """Build and return the Gradio Blocks UI for the matching demo.

    gradio is imported lazily here so that importing this module does not
    fail when gradio is absent (the module-level try/except relies on the
    resulting ModuleNotFoundError).
    """
    import gradio as gr
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # M2-Encoder 0.4B
            Upload one image and enter candidate labels, one per line or comma-separated.
            This Space runs on `CPU Basic`, so the first request can be slow.
            """
        )
        with gr.Row():
            # Image arrives as a numpy array; run_demo converts it to PIL.
            image_input = gr.Image(type="numpy", label="Image")
            labels_input = gr.Textbox(
                label="Candidate Labels",
                lines=8,
                # Default labels (Pokémon names in Chinese) to demo zero-shot matching.
                value="杰尼龟\n妙蛙种子\n小火龙\n皮卡丘",
            )
        run_button = gr.Button("Run Matching", variant="primary")
        summary_output = gr.Textbox(label="Summary")
        details_output = gr.Textbox(label="Results JSON", lines=18)
        run_button.click(
            run_demo,
            inputs=[image_input, labels_input],
            outputs=[summary_output, details_output],
            # Keep the endpoint out of the public API surface.
            api_name=False,
            show_api=False,
        )
    return demo
# Build the UI at import time so `demo` exists at module level (Hugging Face
# Spaces discovers it by attribute); tolerate a missing gradio install so the
# module can still be imported (e.g. for tests) without the UI dependency.
try:
    demo = build_demo()
except ModuleNotFoundError:
    demo = None

if __name__ == "__main__":
    if demo is None:
        raise RuntimeError("gradio is required to launch this app.")
    # Bind on all interfaces; 7860 is the conventional Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)