MATCHA / matcha.py
Siran-Li's picture
Upload matcha.py with huggingface_hub
852f583 verified
"""
matcha.py — Simple interface for loading and using MATCHA.
Usage:
from matcha_metric import MATCHA
model = MATCHA.from_pretrained("Siran-Li/MATCHA")
similarity = model.score("A dog plays in the park.", "A puppy runs outside.")
embeddings = model.encode(["Hello world", "Hi there"])
"""
import importlib.util
import json
import os
import sys
from types import SimpleNamespace
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import GPT2Model, GPT2Tokenizer
def _load_model_module(model_py_path):
    """Dynamically import a Python source file and return the module object.

    The file is executed under the module name ``"matcha_model"``. The module
    is not registered in ``sys.modules``.

    Args:
        model_py_path: Filesystem path to the ``model.py`` source file.

    Returns:
        The freshly executed module.

    Raises:
        ImportError: If no import spec/loader can be built for the path
            (for example, the path does not end in a recognised Python
            source suffix).
    """
    module_name = "matcha_model"
    spec = importlib.util.spec_from_file_location(module_name, model_py_path)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the path; fail with a clear message instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load model definition from {model_py_path!r}",
            name=module_name,
            path=str(model_py_path),
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class MATCHA:
    """Wrapper bundling a MATCHA contrastive model with its tokenizer.

    Provides text encoding, cosine-similarity scoring, and (optionally, via
    captum) token-level attribution and HTML visualization.
    """

    def __init__(self, model, tokenizer, max_length=512, device=None):
        """Store the model/tokenizer pair and move the model to ``device``.

        Args:
            model: Module called as ``model(input_ids)`` to produce embeddings.
            tokenizer: Tokenizer called on a string or list of strings.
            max_length: Length every input is padded/truncated to.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device).eval()
        self.tokenizer = tokenizer
        self.max_length = max_length

    @classmethod
    def from_pretrained(cls, repo_id="Siran-Li/MATCHA", device=None):
        """Load MATCHA from a HuggingFace Hub repo or local directory.

        Args:
            repo_id: Either a local directory containing ``model_config.json``,
                ``max_diff.pth`` and ``model.py``, or a Hub repo id from which
                those three files are downloaded.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.

        Returns:
            A ready-to-use ``MATCHA`` instance.
        """
        if os.path.isdir(repo_id):
            model_dir = repo_id
            config_path = os.path.join(model_dir, "model_config.json")
            checkpoint_path = os.path.join(model_dir, "max_diff.pth")
            model_py_path = os.path.join(model_dir, "model.py")
        else:
            config_path = hf_hub_download(repo_id, "model_config.json")
            checkpoint_path = hf_hub_download(repo_id, "max_diff.pth")
            model_py_path = hf_hub_download(repo_id, "model.py")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, encoding="utf-8") as f:
            model_config = SimpleNamespace(**json.load(f))

        # NOTE(security): this executes model.py taken from repo_id. Only load
        # repositories/directories you trust.
        model_module = _load_model_module(model_py_path)

        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # The pad token is set to the EOS token so that padding works.
        tokenizer.pad_token = tokenizer.eos_token

        backbone = GPT2Model.from_pretrained("gpt2")
        model = model_module.ContrastiveModel(backbone, model_config)

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(security): torch.load unpickles the checkpoint; only load
        # checkpoints from trusted sources.
        weights = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(weights["model_state_dict"])
        return cls(model, tokenizer, device=device)

    def _tokenize(self, texts):
        """Tokenize ``texts`` padded/truncated to ``max_length`` on ``self.device``."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

    def _similarity_forward(self, input_ids, ref_emb):
        """Return the cosine similarity between ``ref_emb`` and the embedding of ``input_ids``."""
        out_emb = self.model(input_ids)
        return F.cosine_similarity(ref_emb, out_emb, dim=-1)

    def encode(self, texts, batch_size=32):
        """Encode texts into embeddings.

        Args:
            texts: A string or an iterable of strings (list, tuple,
                generator, ...).
            batch_size: Number of texts encoded per forward pass.

        Returns:
            Tensor of shape (n, embedding_dim).

        Raises:
            ValueError: If ``texts`` is empty or ``batch_size`` is below 1.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            # Materialize so that generators and other iterables can be
            # measured and sliced below.
            texts = list(texts)
        if not texts:
            raise ValueError("encode() requires at least one text.")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        all_embeddings = []
        for start in range(0, len(texts), batch_size):
            inputs = self._tokenize(texts[start : start + batch_size])
            with torch.no_grad():
                all_embeddings.append(self.model(inputs["input_ids"]))
        return torch.cat(all_embeddings, dim=0)

    def score(self, text1, text2):
        """Compute cosine similarity between two texts.

        Args:
            text1: A string or list of strings.
            text2: A string or list of strings.

        Returns:
            Float (if both inputs are strings) or Tensor of similarities.
        """
        single = isinstance(text1, str) and isinstance(text2, str)
        emb1 = self.encode(text1)
        emb2 = self.encode(text2)
        similarity = F.cosine_similarity(emb1, emb2, dim=-1)
        return similarity.item() if single else similarity

    def _attribute_single(self, input_ids, ref_emb, lig, n_steps):
        """Run Integrated Gradients for one direction and extract tokens + scores.

        Args:
            input_ids: Token ids of the text being attributed; only the first
                row is decoded into tokens.
            ref_emb: Embedding of the other text, passed to the forward
                function as the comparison target.
            lig: Attribution object exposing ``attribute(...)``.
            n_steps: Number of interpolation steps.

        Returns:
            dict with ``tokens`` (non-padding tokens with "Ġ" removed),
            ``attributions`` (one float per token) and ``similarity``.
        """
        # Only the scalar value is needed here, so skip building a graph.
        with torch.no_grad():
            similarity = self._similarity_forward(input_ids, ref_emb).item()

        attributions_ig, _delta = lig.attribute(
            inputs=input_ids,
            # All-zero token ids serve as the baseline (same device/dtype
            # as input_ids by construction).
            baselines=torch.zeros_like(input_ids),
            additional_forward_args=(ref_emb,),
            return_convergence_delta=True,
            n_steps=n_steps,
        )
        # Collapse the last dimension to get one score per token position.
        token_attr = attributions_ig.sum(dim=-1).squeeze(0)

        pad_token = self.tokenizer.pad_token
        tokens = []
        for token in self.tokenizer.convert_ids_to_tokens(input_ids[0]):
            # Everything from the first pad token onward is padding.
            if token == pad_token:
                break
            tokens.append(token.replace("Ġ", ""))
        scores = token_attr[: len(tokens)].tolist()

        return {"tokens": tokens, "attributions": scores, "similarity": similarity}

    def interpret(self, text, reference, n_steps=50):
        """Compute bidirectional token-level attributions using Integrated Gradients.

        Runs two directions:
            1. text → reference: which tokens in `text` contribute to similarity
            2. reference → text: which tokens in `reference` contribute to similarity

        Requires: pip install captum

        Args:
            text: First input text.
            reference: Second input text.
            n_steps: Number of interpolation steps for Integrated Gradients.

        Returns:
            dict with keys:
                - text_to_ref: {tokens, attributions, similarity}
                - ref_to_text: {tokens, attributions, similarity}

        Raises:
            ImportError: If captum is not installed.
        """
        try:
            from captum.attr import LayerIntegratedGradients
        except ImportError as err:
            raise ImportError("captum is required for interpret(). Install with: pip install captum") from err

        text_ids = self._tokenize(text)["input_ids"]
        ref_ids = self._tokenize(reference)["input_ids"]

        # Attributions are taken with respect to the model's word-embedding layer.
        lig = LayerIntegratedGradients(self._similarity_forward, self.model.word_embeddings)

        # Direction 1: text → reference
        with torch.no_grad():
            ref_emb = self.model(ref_ids)
        text_to_ref = self._attribute_single(text_ids, ref_emb, lig, n_steps)

        # Direction 2: reference → text
        with torch.no_grad():
            text_emb = self.model(text_ids)
        ref_to_text = self._attribute_single(ref_ids, text_emb, lig, n_steps)

        return {"text_to_ref": text_to_ref, "ref_to_text": ref_to_text}

    def visualize(self, text, reference, label=None, output_path="attribution.html", n_steps=50):
        """Run bidirectional interpret() and save an HTML heatmap visualization.

        Requires: pip install captum

        Args:
            text: First input text.
            reference: Second input text.
            label: Optional ground truth label ("Correct" or "Incorrect").
            output_path: Path to save the HTML file.
            n_steps: Number of interpolation steps.

        Returns:
            dict from interpret() with text_to_ref and ref_to_text.

        Raises:
            ImportError: If captum is not installed.
        """
        try:
            from captum.attr import visualization
        except ImportError as err:
            raise ImportError("captum is required for visualize(). Install with: pip install captum") from err
        import numpy as np

        result = self.interpret(text, reference, n_steps=n_steps)
        true_label = label or "N/A"

        vis_records = []
        for direction in ("text_to_ref", "ref_to_text"):
            r = result[direction]
            # L2-normalize the attributions for display; an all-zero vector
            # is left untouched to avoid dividing by zero.
            attr_array = np.array(r["attributions"])
            norm = np.linalg.norm(attr_array)
            if norm > 0:
                attr_array = attr_array / norm
            sim_score = round(r["similarity"] * 100, 2)
            pred_label = "Correct" if r["similarity"] > 0 else "Incorrect"
            vis_records.append(visualization.VisualizationDataRecord(
                attr_array,
                sim_score,
                pred_label,
                true_label,
                true_label,
                round(sum(r["attributions"]) * 100, 2),
                r["tokens"],
                None,
            ))

        vis = visualization.visualize_text(vis_records)
        # Tokens may contain non-ASCII characters, so write the HTML as UTF-8
        # rather than relying on the platform's default encoding.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(vis.data)
        print(f"Saved attribution visualization to {output_path}")
        return result