Update README.md
Browse files
README.md
CHANGED
|
@@ -360,6 +360,82 @@ with torch.inference_mode():
|
|
| 360 |
print("Prediction:", pred_class)
|
| 361 |
# 0 = Entailment, 1 = Neutral, 2 = Contradiction
|
| 362 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
---
|
| 365 |
|
|
|
|
| 360 |
print("Prediction:", pred_class)
|
| 361 |
# 0 = Entailment, 1 = Neutral, 2 = Contradiction
|
| 362 |
```
|
| 363 |
+
### Example 4: Using BLASER Semantic Score with MMNLI
|
| 364 |
+
|
| 365 |
+
You can combine the BLASER semantic score with the MMNLI NLI class to get a **better understanding of the relationship** between a source sentence and its candidate translations: the NLI class provides the entailment/neutral/contradiction label, while the BLASER score adds a fine-grained measure of semantic similarity.
|
| 366 |
+
|
| 367 |
+
```python
|
| 368 |
+
|
| 369 |
+
import torch
|
| 370 |
+
from transformers import AutoTokenizer, AutoModel
|
| 371 |
+
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
|
| 372 |
+
|
| 373 |
+
# -------------------------
|
| 374 |
+
# 1️⃣ Load ported SONAR text encoder
|
| 375 |
+
# -------------------------
|
| 376 |
+
sonar_model_name = "cointegrated/SONAR_200_text_encoder"
|
| 377 |
+
encoder = M2M100Encoder.from_pretrained(sonar_model_name)
|
| 378 |
+
tokenizer = AutoTokenizer.from_pretrained(sonar_model_name)
|
| 379 |
+
|
| 380 |
+
def encode_mean_pool(texts, tokenizer, encoder, lang='eng_Latn', norm=False):
|
| 381 |
+
tokenizer.src_lang = lang
|
| 382 |
+
with torch.inference_mode():
|
| 383 |
+
batch = tokenizer(texts, return_tensors='pt', padding=True)
|
| 384 |
+
seq_embs = encoder(**batch).last_hidden_state
|
| 385 |
+
mask = batch.attention_mask
|
| 386 |
+
mean_emb = (seq_embs * mask.unsqueeze(-1)).sum(1) / mask.unsqueeze(-1).sum(1)
|
| 387 |
+
if norm:
|
| 388 |
+
mean_emb = torch.nn.functional.normalize(mean_emb)
|
| 389 |
+
return mean_emb
|
| 390 |
+
|
| 391 |
+
# -------------------------
|
| 392 |
+
# 2️⃣ Example sentences
|
| 393 |
+
# -------------------------
|
| 394 |
+
src_sentence = ["He is happy."]
|
| 395 |
+
mt_sentences = [
|
| 396 |
+
"Il est content.",       # entailment — blaser: 4.515
|
| 397 |
+
"Il est malheureux."     # contradiction — blaser: 4.41
|
| 398 |
+
]
|
| 399 |
+
|
| 400 |
+
# Encode source and MT sentences
|
| 401 |
+
src_embs = encode_mean_pool(src_sentence, tokenizer, encoder, lang="eng_Latn")
|
| 402 |
+
mt_embs = encode_mean_pool(mt_sentences, tokenizer, encoder, lang="fra_Latn")
|
| 403 |
+
|
| 404 |
+
# -------------------------
|
| 405 |
+
# 3️⃣ Load MMNLI model
|
| 406 |
+
# -------------------------
|
| 407 |
+
mmnli_model_name = "oist/multimodal_nli_model"
|
| 408 |
+
mmnli_model = AutoModel.from_pretrained(mmnli_model_name, trust_remote_code=True)
|
| 409 |
+
mmnli_model.eval()
|
| 410 |
+
|
| 411 |
+
# -------------------------
|
| 412 |
+
# 4️⃣ Load BLASER QE model
|
| 413 |
+
# -------------------------
|
| 414 |
+
qe_model_name = "oist/blaser_2_0_qe_ported"
|
| 415 |
+
qe_model = AutoModel.from_pretrained(qe_model_name, trust_remote_code=True)
|
| 416 |
+
qe_model.eval()
|
| 417 |
+
|
| 418 |
+
# -------------------------
|
| 419 |
+
# 5️⃣ Run inference
|
| 420 |
+
# -------------------------
|
| 421 |
+
for i, mt_sentence in enumerate(mt_sentences):
|
| 422 |
+
mt_emb = mt_embs[i].unsqueeze(0) # keep batch dimension
|
| 423 |
+
|
| 424 |
+
# NLI prediction
|
| 425 |
+
with torch.inference_mode():
|
| 426 |
+
logits = mmnli_model(src_embs, mt_emb)
|
| 427 |
+
pred_class = torch.argmax(logits, dim=-1).item()
|
| 428 |
+
|
| 429 |
+
# BLASER semantic score
|
| 430 |
+
with torch.inference_mode():
|
| 431 |
+
qe_score = qe_model(src_embs, mt_emb) # shape [1, 1]
|
| 432 |
+
|
| 433 |
+
print(f"\nMT sentence: '{mt_sentence}'")
|
| 434 |
+
print("NLI prediction:", ["Entailment", "Neutral", "Contradiction"][pred_class])
|
| 435 |
+
print("BLASER semantic score:", qe_score.item())
|
| 436 |
+
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
|
| 440 |
---
|
| 441 |
|