tomaarsen HF Staff commited on
Commit
b67abfb
·
verified ·
1 Parent(s): 62a9793

Create training_nli_v2.py

Browse files
Files changed (1) hide show
  1. training_nli_v2.py +155 -0
training_nli_v2.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) on the SNLI + MultiNLI (AllNLI) dataset
3
+ with MultipleNegativesRankingLoss. Entailments are positive pairs and the contradiction on AllNLI dataset is added as a hard negative.
4
+ At every 10% training steps, the model is evaluated on the STS benchmark dataset
5
+
6
+ Usage:
7
+ python training_nli_v2.py
8
+
9
+ OR
10
+ python training_nli_v2.py pretrained_transformer_model_name
11
+ """
12
+ import math
13
+ from sentence_transformers import models, losses, datasets
14
+ from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
15
+ from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
16
+ import logging
17
+ from datetime import datetime
18
+ import sys
19
+ import os
20
+ import gzip
21
+ import csv
22
+ import random
23
+
24
#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
#### /print debug information to stdout
29
+
30
# Base transformer model: first CLI argument, or mpnet-base by default.
model_name = sys.argv[1] if len(sys.argv) > 1 else "microsoft/mpnet-base"
train_batch_size = 2048  # The larger you select this, the better the results (usually). But it requires more GPU memory
max_seq_length = 75
num_epochs = 1

# Save path of the model: output/training_nli_v2_<model>-<timestamp>
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = "output/training_nli_v2_" + model_name.replace("/", "-") + "-" + timestamp
39
+
40
+
41
# Here we define our SentenceTransformer model: a transformer encoder
# followed by mean pooling over the token embeddings.
transformer = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])
45
+
46
# Check if the datasets exist locally. If not, download them.
nli_dataset_path = "data/AllNLI.tsv.gz"
sts_dataset_path = "data/stsbenchmark.tsv.gz"

for url, path in (
    ("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path),
    ("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path),
):
    if not os.path.exists(path):
        util.http_get(url, path)


# Read the AllNLI.tsv.gz file and create the training dataset
logging.info("Read AllNLI train dataset")
59
+
60
+
61
def add_to_samples(sent1, sent2, label):
    """Record `sent2` under `label` for anchor sentence `sent1`.

    Populates the module-level `train_data` dict, creating the per-anchor
    label buckets ("contradiction"/"entailment"/"neutral") on first sight
    of `sent1`. Sets deduplicate repeated pairs automatically.
    """
    buckets = train_data.setdefault(
        sent1, {"contradiction": set(), "entailment": set(), "neutral": set()}
    )
    buckets[label].add(sent2)
65
+
66
+
67
# Map each train-split sentence to the sets of sentences it entails,
# contradicts, or is neutral to.
train_data = {}
with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn:
    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if row["split"] != "train":
            continue
        sent1 = row["sentence1"].strip()
        sent2 = row["sentence2"].strip()
        # Register the pair in both directions so either sentence can anchor.
        add_to_samples(sent1, sent2, row["label"])
        add_to_samples(sent2, sent1, row["label"])
77
+
78
+
79
# Build (anchor, entailment, contradiction) triplets; the contradiction acts
# as a hard negative for MultipleNegativesRankingLoss. Anchors missing either
# an entailment or a contradiction partner are skipped.
train_samples = []
for anchor, related in train_data.items():
    entailments = list(related["entailment"])
    contradictions = list(related["contradiction"])
    if not entailments or not contradictions:
        continue
    # One triplet with the anchor first, and one with the entailment first.
    train_samples.append(
        InputExample(texts=[anchor, random.choice(entailments), random.choice(contradictions)])
    )
    train_samples.append(
        InputExample(texts=[random.choice(entailments), anchor, random.choice(contradictions)])
    )

logging.info("Train samples: {}".format(len(train_samples)))
94
+
95
+
96
# Special data loader that avoids duplicate sentences within a batch —
# duplicates would act as false negatives for in-batch negative sampling.
train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size)


# Our training loss: the cached variant processes the large logical batch in
# mini-batches of 64 to keep GPU memory bounded.
train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
102
+
103
+
104
# Read the STS benchmark dev split and use it as the development set.
logging.info("Read STSbenchmark dev dataset")
dev_samples = []
with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
    for row in csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE):
        if row["split"] == "dev":
            # Gold similarity scores are 0..5; normalize to the range 0..1.
            dev_samples.append(
                InputExample(
                    texts=[row["sentence1"], row["sentence2"]],
                    label=float(row["score"]) / 5.0,
                )
            )

dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    dev_samples, batch_size=train_batch_size, name="sts-dev"
)
117
+
118
# Configure the training: 10% of the train steps are used for LR warm-up.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))


# Train the model, evaluating on the STS dev set every 10% of an epoch.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=int(len(train_dataloader) * 0.1),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,  # mixed-precision training; requires a GPU with FP16 support
)
133
+
134
# Push the model to the Hugging Face hub. The push needs a network connection
# and valid HF credentials; a failure here must not abort the script before
# the final test-set evaluation below (the model is already saved locally at
# model_save_path), so treat the upload as best-effort and log any error.
try:
    model.save_to_hub("tomaarsen/mpnet-base-all-nli")
except Exception as e:
    logging.error("Failed to push the model to the Hugging Face Hub: {}".format(e))
136
+
137
##############################################################################
#
# Load the stored model and evaluate its performance on STS benchmark dataset
#
##############################################################################

# Collect the STS benchmark test split, normalizing gold scores from 0..5 to 0..1.
test_samples = []
with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
    for row in csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE):
        if row["split"] != "test":
            continue
        normalized_score = float(row["score"]) / 5.0
        test_samples.append(
            InputExample(texts=[row["sentence1"], row["sentence2"]], label=normalized_score)
        )

# Reload the saved checkpoint from disk and score it on the test split.
model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    test_samples, batch_size=train_batch_size, name="sts-test"
)
test_evaluator(model, output_path=model_save_path)