""" matcha.py — Simple interface for loading and using MATCHA. Usage: from matcha_metric import MATCHA model = MATCHA.from_pretrained("Siran-Li/MATCHA") similarity = model.score("A dog plays in the park.", "A puppy runs outside.") embeddings = model.encode(["Hello world", "Hi there"]) """ import importlib.util import json import os import sys from types import SimpleNamespace import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from transformers import GPT2Model, GPT2Tokenizer def _load_model_module(model_py_path): """Dynamically import model.py from a file path.""" spec = importlib.util.spec_from_file_location("matcha_model", model_py_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module class MATCHA: def __init__(self, model, tokenizer, max_length=512, device=None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device).eval() self.tokenizer = tokenizer self.max_length = max_length @classmethod def from_pretrained(cls, repo_id="Siran-Li/MATCHA", device=None): """Load MATCHA from a HuggingFace Hub repo or local directory.""" if os.path.isdir(repo_id): model_dir = repo_id config_path = os.path.join(model_dir, "model_config.json") checkpoint_path = os.path.join(model_dir, "max_diff.pth") model_py_path = os.path.join(model_dir, "model.py") else: config_path = hf_hub_download(repo_id, "model_config.json") checkpoint_path = hf_hub_download(repo_id, "max_diff.pth") model_py_path = hf_hub_download(repo_id, "model.py") with open(config_path) as f: model_config = SimpleNamespace(**json.load(f)) model_module = _load_model_module(model_py_path) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token backbone = GPT2Model.from_pretrained("gpt2") model = model_module.ContrastiveModel(backbone, model_config) device = device or ("cuda" if torch.cuda.is_available() else "cpu") weights = torch.load(checkpoint_path, map_location=device) model.load_state_dict(weights["model_state_dict"]) return cls(model, tokenizer, device=device) def encode(self, texts, batch_size=32): """Encode texts into embeddings. Args: texts: A string or list of strings. batch_size: Batch size for encoding. Returns: Tensor of shape (n, embedding_dim). """ if isinstance(texts, str): texts = [texts] all_embeddings = [] for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] inputs = self.tokenizer( batch, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length, ).to(self.device) with torch.no_grad(): embeddings = self.model(inputs["input_ids"]) all_embeddings.append(embeddings) return torch.cat(all_embeddings, dim=0) def score(self, text1, text2): """Compute cosine similarity between two texts. Args: text1: A string or list of strings. text2: A string or list of strings. Returns: Float (if inputs are strings) or Tensor of similarities. """ single = isinstance(text1, str) and isinstance(text2, str) emb1 = self.encode(text1) emb2 = self.encode(text2) similarity = F.cosine_similarity(emb1, emb2, dim=-1) return similarity.item() if single else similarity def _attribute_single(self, input_ids, ref_emb, lig, n_steps): """Run Integrated Gradients for one direction and extract tokens + scores.""" def forward_func(input_ids, ref_emb=None): out_emb = self.model(input_ids) return F.cosine_similarity(ref_emb, out_emb, dim=-1) similarity = forward_func(input_ids, ref_emb).item() attributions_ig, delta = lig.attribute( inputs=input_ids, baselines=torch.zeros_like(input_ids).to(self.device), additional_forward_args=(ref_emb,), return_convergence_delta=True, n_steps=n_steps, ) token_attr = attributions_ig.sum(dim=-1).squeeze(0) all_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0]) pad_token = self.tokenizer.pad_token tokens = [] scores = [] for i, t in enumerate(all_tokens): if t == pad_token: break tokens.append(t.replace("Ġ", "")) scores.append(token_attr[i].item()) return {"tokens": tokens, "attributions": scores, "similarity": similarity} def interpret(self, text, reference, n_steps=50): """Compute bidirectional token-level attributions using Integrated Gradients. Runs two directions: 1. text → reference: which tokens in `text` contribute to similarity 2. reference → text: which tokens in `reference` contribute to similarity Requires: pip install captum Args: text: First input text. reference: Second input text. n_steps: Number of interpolation steps for Integrated Gradients. Returns: dict with keys: - text_to_ref: {tokens, attributions, similarity} - ref_to_text: {tokens, attributions, similarity} """ try: from captum.attr import LayerIntegratedGradients except ImportError: raise ImportError("captum is required for interpret(). Install with: pip install captum") text_inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length).to(self.device) ref_inputs = self.tokenizer(reference, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length).to(self.device) text_ids = text_inputs["input_ids"] ref_ids = ref_inputs["input_ids"] def forward_func(input_ids, ref_emb=None): out_emb = self.model(input_ids) return F.cosine_similarity(ref_emb, out_emb, dim=-1) embedding_layer = self.model.word_embeddings lig = LayerIntegratedGradients(forward_func, embedding_layer) # Direction 1: text → reference with torch.no_grad(): ref_emb = self.model(ref_ids) text_to_ref = self._attribute_single(text_ids, ref_emb, lig, n_steps) # Direction 2: reference → text with torch.no_grad(): text_emb = self.model(text_ids) ref_to_text = self._attribute_single(ref_ids, text_emb, lig, n_steps) return {"text_to_ref": text_to_ref, "ref_to_text": ref_to_text} def visualize(self, text, reference, label=None, output_path="attribution.html", n_steps=50): """Run bidirectional interpret() and save an HTML heatmap visualization. Requires: pip install captum Args: text: First input text. reference: Second input text. label: Optional ground truth label ("Correct" or "Incorrect"). output_path: Path to save the HTML file. n_steps: Number of interpolation steps. Returns: dict from interpret() with text_to_ref and ref_to_text. """ try: from captum.attr import visualization except ImportError: raise ImportError("captum is required for visualize(). Install with: pip install captum") import numpy as np result = self.interpret(text, reference, n_steps=n_steps) vis_records = [] true_label = label or "N/A" for direction, dir_label in [("text_to_ref", "text → ref"), ("ref_to_text", "ref → text")]: r = result[direction] attr_array = np.array(r["attributions"]) norm = np.linalg.norm(attr_array) if norm > 0: attr_array = attr_array / norm sim_score = round(r["similarity"] * 100, 2) pred_label = "Correct" if r["similarity"] > 0 else "Incorrect" vis_records.append(visualization.VisualizationDataRecord( attr_array, sim_score, pred_label, true_label, true_label, round(sum(r["attributions"]) * 100, 2), r["tokens"], None, )) vis = visualization.visualize_text(vis_records) with open(output_path, "w") as f: f.write(vis.data) print(f"Saved attribution visualization to {output_path}") return result