---
language: en
license: mit
library_name: transformers
pipeline_tag: sentence-similarity
tags:
- sentence-embeddings
- retrieval
- contrastive-learning
- bert
base_model: bert-base-uncased
datasets:
- nyu-mll/multi_nli
metrics:
- cosine
---

# Vectra

BERT-base sentence embeddings trained with in-batch contrastive learning (Multiple Negatives Ranking Loss) on MultiNLI entailment pairs.

## Model

- Base: `bert-base-uncased`
- Pooling: mean pooling over token embeddings (attention-masked)
- Normalization: L2
- Objective: MNRL / InfoNCE-style softmax with temperature 0.05 (see the loss sketch after this list)
- Training data: MultiNLI entailment pairs (subset)
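
For reference, a minimal sketch of the objective, assuming entailment pairs batched as (premise, hypothesis) with L2-normalized embeddings. The actual training code is not part of this card, so the function and variable names below are illustrative:

```python
import torch
import torch.nn.functional as F

def mnrl_loss(anchor_emb, positive_emb, temperature=0.05):
    """In-batch Multiple Negatives Ranking Loss (InfoNCE-style softmax).

    anchor_emb, positive_emb: (batch, dim) L2-normalized embeddings.
    Row i of positive_emb is the positive for row i of anchor_emb;
    all other rows in the batch act as in-batch negatives.
    """
    # Cosine similarities (embeddings are unit-norm), scaled by temperature.
    scores = anchor_emb @ positive_emb.T / temperature  # (batch, batch)
    # The correct "class" for anchor i is column i.
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```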

## Usage (embeddings)

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def mean_pooling(last_hidden_state, attention_mask):
    # Average token embeddings, ignoring padding positions.
    mask = attention_mask.unsqueeze(-1).to(dtype=last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-6)
    return summed / counts


@torch.no_grad()
def embed_texts(texts, model_id="rafidka/vectra", max_length=128,
                device="cuda" if torch.cuda.is_available() else "cpu"):
    # Note: loads the tokenizer and model on every call; cache them for repeated use.
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    # add_pooling_layer=False: the BERT pooler head is unused; we mean-pool instead.
    model = AutoModel.from_pretrained(model_id, add_pooling_layer=False).to(device).eval()
    batch = tok(texts, padding=True, truncation=True, max_length=max_length,
                return_tensors="pt").to(device)
    out = model(**batch)
    emb = mean_pooling(out.last_hidden_state, batch["attention_mask"])
    # L2-normalize so dot products are cosine similarities.
    emb = F.normalize(emb, p=2, dim=-1)
    return emb
```
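
A quick end-to-end check; since the embeddings are unit-norm, cosine similarity reduces to a matrix product (the sentences below are placeholders):

```python
sentences = [
    "A man is playing a guitar.",
    "Someone is performing music.",
    "The stock market fell sharply today.",
]
emb = embed_texts(sentences)
# Unit-norm embeddings: emb @ emb.T is the cosine-similarity matrix.
print(emb @ emb.T)
```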