| """ |
| matcha.py — Simple interface for loading and using MATCHA. |
| |
| Usage: |
| from matcha_metric import MATCHA |
| |
| model = MATCHA.from_pretrained("Siran-Li/MATCHA") |
| similarity = model.score("A dog plays in the park.", "A puppy runs outside.") |
| embeddings = model.encode(["Hello world", "Hi there"]) |
| """ |
|
|
| import importlib.util |
| import json |
| import os |
| import sys |
| from types import SimpleNamespace |
|
|
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
| from transformers import GPT2Model, GPT2Tokenizer |
|
|
|
|
def _load_model_module(model_py_path):
    """Import a Python source file (``model.py``) from an explicit path.

    The file is executed as a fresh module named ``matcha_model`` on every
    call; the module is not registered in ``sys.modules``.

    Security note: this executes arbitrary code from ``model_py_path``; only
    load files from sources you trust.

    Args:
        model_py_path: Path to a Python source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be built for the path
            (e.g. it does not end in a Python source suffix such as ``.py``).
        FileNotFoundError: If the file does not exist.
    """
    spec = importlib.util.spec_from_file_location("matcha_model", model_py_path)
    # spec_from_file_location returns None for unrecognized suffixes; without
    # this check the failure surfaces as an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_py_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
class MATCHA:
    """Inference wrapper around a MATCHA ``ContrastiveModel``.

    Exposes sentence embeddings (:meth:`encode`), cosine similarity
    (:meth:`score`) and Integrated-Gradients token attributions
    (:meth:`interpret`, :meth:`visualize`).
    """

    def __init__(self, model, tokenizer, max_length=512, device=None):
        """Wrap an already-built model and tokenizer.

        Args:
            model: ``torch.nn.Module`` mapping a batch of ``input_ids`` to
                embeddings. It is moved to ``device`` and put in eval mode.
            tokenizer: HuggingFace-style tokenizer (callable, with
                ``pad_token`` and ``convert_ids_to_tokens``).
            max_length: Every input is padded/truncated to this many tokens.
            device: Torch device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device or self._default_device()
        self.model = model.to(self.device).eval()
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _default_device():
        """Return ``"cuda"`` if torch sees a GPU, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def from_pretrained(cls, repo_id="Siran-Li/MATCHA", device=None, max_length=512):
        """Load MATCHA from a HuggingFace Hub repo or a local directory.

        The source must provide ``model_config.json``, ``max_diff.pth`` (a
        checkpoint dict containing ``"model_state_dict"``) and ``model.py``
        (defining ``ContrastiveModel(backbone, config)``).

        Security note: ``model.py`` is executed and ``max_diff.pth`` is
        unpickled by ``torch.load``; only load repositories you trust.

        Args:
            repo_id: Hub repo id, or a path to a local directory holding the
                three files above.
            device: Torch device; defaults to CUDA when available, else CPU.
            max_length: Token length used for padding/truncation.

        Returns:
            A ready-to-use :class:`MATCHA` instance.
        """
        if os.path.isdir(repo_id):
            config_path = os.path.join(repo_id, "model_config.json")
            checkpoint_path = os.path.join(repo_id, "max_diff.pth")
            model_py_path = os.path.join(repo_id, "model.py")
        else:
            config_path = hf_hub_download(repo_id, "model_config.json")
            checkpoint_path = hf_hub_download(repo_id, "max_diff.pth")
            model_py_path = hf_hub_download(repo_id, "model.py")

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(config_path, encoding="utf-8") as f:
            model_config = SimpleNamespace(**json.load(f))

        model_module = _load_model_module(model_py_path)

        # GPT-2 has no pad token; reuse EOS so batches can be padded.
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        backbone = GPT2Model.from_pretrained("gpt2")

        model = model_module.ContrastiveModel(backbone, model_config)
        device = device or cls._default_device()
        # NOTE(security): torch.load unpickles the checkpoint file.
        weights = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(weights["model_state_dict"])

        return cls(model, tokenizer, max_length=max_length, device=device)

    def _tokenize(self, texts):
        """Tokenize a string or list of strings to ``max_length`` on ``self.device``."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

    def encode(self, texts, batch_size=32):
        """Encode texts into embeddings.

        Only ``input_ids`` are passed to the model (no attention mask), so
        padding positions are visible to it; presumably ``model.py`` pools
        over them appropriately — verify against the model definition.

        Args:
            texts: A string or list of strings.
            batch_size: Number of texts tokenized and embedded per forward pass.

        Returns:
            Tensor of embeddings, one row per input text, i.e. shape
            ``(n, embedding_dim)`` assuming the model returns 2-D output.
        """
        if isinstance(texts, str):
            texts = [texts]

        chunks = []
        for start in range(0, len(texts), batch_size):
            inputs = self._tokenize(texts[start : start + batch_size])
            with torch.no_grad():
                chunks.append(self.model(inputs["input_ids"]))

        return torch.cat(chunks, dim=0)

    def score(self, text1, text2):
        """Compute cosine similarity between two texts.

        Args:
            text1: A string or list of strings.
            text2: A string or list of strings.

        Returns:
            Float if both inputs are strings, otherwise a Tensor of
            similarities (row-wise, with broadcasting of single inputs).
        """
        single = isinstance(text1, str) and isinstance(text2, str)
        emb1 = self.encode(text1)
        emb2 = self.encode(text2)
        similarity = F.cosine_similarity(emb1, emb2, dim=-1)
        return similarity.item() if single else similarity

    def _similarity_forward(self, input_ids, ref_emb):
        """Cosine similarity between ``self.model(input_ids)`` and ``ref_emb``.

        Used as the Captum forward function; ``ref_emb`` arrives through
        ``additional_forward_args``.
        """
        return F.cosine_similarity(ref_emb, self.model(input_ids), dim=-1)

    def _collect_tokens(self, input_ids, token_attr, attention_mask=None):
        """Pair each real (non-padding) token of ``input_ids[0]`` with its score.

        Padding is identified from ``attention_mask`` when given. Otherwise
        scanning stops at the first pad token; note that with the tokenizer
        built in :meth:`from_pretrained` pad == EOS, so a literal EOS in the
        text would also stop the scan — pass the mask to avoid that.

        Returns:
            ``(tokens, scores)``: token strings with the GPT-2 leading-space
            marker ``"Ġ"`` removed, and one float score per token.
        """
        all_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
        if attention_mask is not None:
            keep = [bool(flag) for flag in attention_mask[0].tolist()]
        else:
            pad_token = self.tokenizer.pad_token
            keep = []
            for tok in all_tokens:
                if tok == pad_token:
                    break
                keep.append(True)

        tokens, scores = [], []
        # zip stops at the end of `keep`, which handles the pad-token break.
        for i, (tok, is_real) in enumerate(zip(all_tokens, keep)):
            if is_real:
                tokens.append(tok.replace("Ġ", ""))
                scores.append(token_attr[i].item())
        return tokens, scores

    def _attribute_single(self, input_ids, ref_emb, lig, n_steps, attention_mask=None):
        """Run Integrated Gradients for one direction and extract tokens + scores.

        Args:
            input_ids: Token ids of a single text (batch dimension of 1).
            ref_emb: Fixed embedding the text's similarity is measured against.
            lig: Captum ``LayerIntegratedGradients`` built on
                :meth:`_similarity_forward`.
            n_steps: Number of interpolation steps.
            attention_mask: Optional mask used to drop padding from the output.

        Returns:
            dict with ``tokens``, ``attributions`` (one float per token) and
            ``similarity`` (float cosine similarity).
        """
        # The reported similarity needs no autograd graph.
        with torch.no_grad():
            similarity = self._similarity_forward(input_ids, ref_emb).item()

        # Baseline is the all-zero token-id sequence. The convergence delta
        # was never used, so it is no longer requested.
        attributions = lig.attribute(
            inputs=input_ids,
            baselines=torch.zeros_like(input_ids),
            additional_forward_args=(ref_emb,),
            n_steps=n_steps,
        )

        # Sum over the last (embedding) dimension: one score per token.
        token_attr = attributions.sum(dim=-1).squeeze(0)
        tokens, scores = self._collect_tokens(input_ids, token_attr, attention_mask)
        return {"tokens": tokens, "attributions": scores, "similarity": similarity}

    def interpret(self, text, reference, n_steps=50):
        """Compute bidirectional token-level attributions using Integrated Gradients.

        Runs two directions:
            1. text → reference: which tokens in `text` contribute to similarity
            2. reference → text: which tokens in `reference` contribute to similarity

        Attributions are taken with respect to ``self.model.word_embeddings``.

        Requires: pip install captum

        Args:
            text: First input text.
            reference: Second input text.
            n_steps: Number of interpolation steps for Integrated Gradients.

        Returns:
            dict with keys:
                - text_to_ref: {tokens, attributions, similarity}
                - ref_to_text: {tokens, attributions, similarity}

        Raises:
            ImportError: If captum is not installed.
        """
        try:
            from captum.attr import LayerIntegratedGradients
        except ImportError as err:
            raise ImportError(
                "captum is required for interpret(). Install with: pip install captum"
            ) from err

        text_inputs = self._tokenize(text)
        ref_inputs = self._tokenize(reference)
        text_ids = text_inputs["input_ids"]
        ref_ids = ref_inputs["input_ids"]

        lig = LayerIntegratedGradients(self._similarity_forward, self.model.word_embeddings)

        # Direction 1: attribute `text` tokens against the fixed reference embedding.
        with torch.no_grad():
            ref_emb = self.model(ref_ids)
        text_to_ref = self._attribute_single(
            text_ids, ref_emb, lig, n_steps,
            attention_mask=text_inputs.get("attention_mask"),
        )

        # Direction 2: attribute `reference` tokens against the fixed text embedding.
        with torch.no_grad():
            text_emb = self.model(text_ids)
        ref_to_text = self._attribute_single(
            ref_ids, text_emb, lig, n_steps,
            attention_mask=ref_inputs.get("attention_mask"),
        )

        return {"text_to_ref": text_to_ref, "ref_to_text": ref_to_text}

    def visualize(self, text, reference, label=None, output_path="attribution.html",
                  n_steps=50, threshold=0.0):
        """Run bidirectional interpret() and save an HTML heatmap visualization.

        Requires: pip install captum

        Args:
            text: First input text.
            reference: Second input text.
            label: Optional ground truth label ("Correct" or "Incorrect").
            output_path: Path to save the HTML file (written as UTF-8).
            n_steps: Number of interpolation steps.
            threshold: Similarity above which a pair is shown as "Correct".

        Returns:
            dict from interpret() with text_to_ref and ref_to_text.

        Raises:
            ImportError: If captum is not installed.
        """
        try:
            from captum.attr import visualization
        except ImportError as err:
            raise ImportError(
                "captum is required for visualize(). Install with: pip install captum"
            ) from err

        import numpy as np

        result = self.interpret(text, reference, n_steps=n_steps)
        true_label = label or "N/A"

        vis_records = []
        for direction, dir_label in [("text_to_ref", "text → ref"), ("ref_to_text", "ref → text")]:
            r = result[direction]
            attr_array = np.array(r["attributions"])
            norm = np.linalg.norm(attr_array)
            if norm > 0:
                # Unit L2 norm so both rows share a comparable colour scale.
                attr_array = attr_array / norm

            similarity = r["similarity"]
            vis_records.append(visualization.VisualizationDataRecord(
                attr_array,
                round(similarity * 100, 2),
                "Correct" if similarity > threshold else "Incorrect",
                true_label,
                # Attribution-label column: name the direction so the two
                # rows can be told apart (it previously repeated true_label).
                dir_label,
                round(sum(r["attributions"]) * 100, 2),
                r["tokens"],
                None,
            ))

        vis = visualization.visualize_text(vis_records)
        # Explicit UTF-8: direction labels and tokens may be non-ASCII.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(vis.data)

        print(f"Saved attribution visualization to {output_path}")
        return result
|
|