oist committed on
Commit
09e9a83
·
verified ·
1 Parent(s): 9cfb8ad

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +76 -0
README.md CHANGED
@@ -360,6 +360,82 @@ with torch.inference_mode():
360
  print("Prediction:", pred_class)
361
  # 0 = Entailment, 1 = Neutral, 2 = Contradiction
362
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  ---
365
 
 
360
  print("Prediction:", pred_class)
361
  # 0 = Entailment, 1 = Neutral, 2 = Contradiction
362
  ```
363
+ ### Example 4: Using BLASER Semantic Score with MMNLI
364
+
365
+ You can use the BLASER semantic score in combination with the MMNLI NLI class to get a **better understanding of the relationship** between source and candidate translations. The NLI class gives the entailment/contradiction/neutral label, while the BLASER score provides a fine-grained semantic similarity score.
366
+
367
+ ```python
368
+
369
+ import torch
370
+ from transformers import AutoTokenizer, AutoModel
371
+ from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
372
+
373
+ # -------------------------
374
+ # 1️⃣ Load ported SONAR text encoder
375
+ # -------------------------
376
+ sonar_model_name = "cointegrated/SONAR_200_text_encoder"
377
+ encoder = M2M100Encoder.from_pretrained(sonar_model_name)
378
+ tokenizer = AutoTokenizer.from_pretrained(sonar_model_name)
379
+
380
+ def encode_mean_pool(texts, tokenizer, encoder, lang='eng_Latn', norm=False):
381
+ tokenizer.src_lang = lang
382
+ with torch.inference_mode():
383
+ batch = tokenizer(texts, return_tensors='pt', padding=True)
384
+ seq_embs = encoder(**batch).last_hidden_state
385
+ mask = batch.attention_mask
386
+ mean_emb = (seq_embs * mask.unsqueeze(-1)).sum(1) / mask.unsqueeze(-1).sum(1)
387
+ if norm:
388
+ mean_emb = torch.nn.functional.normalize(mean_emb)
389
+ return mean_emb
390
+
391
+ # -------------------------
392
+ # 2️⃣ Example sentences
393
+ # -------------------------
394
+ src_sentence = ["He is happy."]
395
+ mt_sentences = [
396
+ "Il est content.", # entailment blaser:4.515
397
+ "Il est malheureux." # contradiction blaser: 4.41
398
+ ]
399
+
400
+ # Encode source and MT sentences
401
+ src_embs = encode_mean_pool(src_sentence, tokenizer, encoder, lang="eng_Latn")
402
+ mt_embs = encode_mean_pool(mt_sentences, tokenizer, encoder, lang="fra_Latn")
403
+
404
+ # -------------------------
405
+ # 3️⃣ Load MMNLI model
406
+ # -------------------------
407
+ mmnli_model_name = "oist/multimodal_nli_model"
408
+ mmnli_model = AutoModel.from_pretrained(mmnli_model_name, trust_remote_code=True)
409
+ mmnli_model.eval()
410
+
411
+ # -------------------------
412
+ # 4️⃣ Load BLASER QE model
413
+ # -------------------------
414
+ qe_model_name = "oist/blaser_2_0_qe_ported"
415
+ qe_model = AutoModel.from_pretrained(qe_model_name, trust_remote_code=True)
416
+ qe_model.eval()
417
+
418
+ # -------------------------
419
+ # 5️⃣ Run inference
420
+ # -------------------------
421
+ for i, mt_sentence in enumerate(mt_sentences):
422
+ mt_emb = mt_embs[i].unsqueeze(0) # keep batch dimension
423
+
424
+ # NLI prediction
425
+ with torch.inference_mode():
426
+ logits = mmnli_model(src_embs, mt_emb)
427
+ pred_class = torch.argmax(logits, dim=-1).item()
428
+
429
+ # BLASER semantic score
430
+ with torch.inference_mode():
431
+ qe_score = qe_model(src_embs, mt_emb) # shape [1, 1]
432
+
433
+ print(f"\nMT sentence: '{mt_sentence}'")
434
+ print("NLI prediction:", ["Entailment", "Neutral", "Contradiction"][pred_class])
435
+ print("BLASER semantic score:", qe_score.item())
436
+
437
+ ```
438
+
439
 
440
  ---
441