Update modeling_retrieva_bert.py
Browse files
modeling_retrieva_bert.py
CHANGED
|
@@ -65,7 +65,7 @@ from .configuration_retrieva_bert import RetrievaBertConfig
|
|
| 65 |
logger = logging.get_logger(__name__)
|
| 66 |
|
| 67 |
_CONFIG_FOR_DOC = "RetrievaBertConfig"
|
| 68 |
-
_CHECKPOINT_FOR_DOC = "
|
| 69 |
|
| 70 |
|
| 71 |
def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
|
|
@@ -1170,8 +1170,8 @@ class RetrievaBertForPreTraining(RetrievaBertPreTrainedModel):
|
|
| 1170 |
>>> from models import RetrievaBertForPreTraining
|
| 1171 |
>>> import torch
|
| 1172 |
|
| 1173 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("
|
| 1174 |
-
>>> model = RetrievaBertForPreTraining.from_pretrained("
|
| 1175 |
|
| 1176 |
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 1177 |
>>> outputs = model(**inputs)
|
|
@@ -1294,8 +1294,8 @@ class RetrievaBertForCausalLM(RetrievaBertPreTrainedModel):
|
|
| 1294 |
>>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
|
| 1295 |
>>> import torch
|
| 1296 |
|
| 1297 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("
|
| 1298 |
-
>>> model = RetrievaBertForCausalLM.from_pretrained("
|
| 1299 |
|
| 1300 |
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 1301 |
>>> outputs = model(**inputs)
|
|
@@ -1528,8 +1528,8 @@ class RetrievaBertForNextSentencePrediction(RetrievaBertPreTrainedModel):
|
|
| 1528 |
>>> from models import RetrievaBertForNextSentencePrediction
|
| 1529 |
>>> import torch
|
| 1530 |
|
| 1531 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("
|
| 1532 |
-
>>> model = RetrievaBertForNextSentencePrediction.from_pretrained("
|
| 1533 |
|
| 1534 |
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 1535 |
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
|
|
|
| 65 |
logger = logging.get_logger(__name__)
|
| 66 |
|
| 67 |
_CONFIG_FOR_DOC = "RetrievaBertConfig"
|
| 68 |
+
_CHECKPOINT_FOR_DOC = "retrieva-jp/bert-1.3b"
|
| 69 |
|
| 70 |
|
| 71 |
def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
|
|
|
|
| 1170 |
>>> from models import RetrievaBertForPreTraining
|
| 1171 |
>>> import torch
|
| 1172 |
|
| 1173 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
|
| 1174 |
+
>>> model = RetrievaBertForPreTraining.from_pretrained("retrieva-jp/bert-1.3b")
|
| 1175 |
|
| 1176 |
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 1177 |
>>> outputs = model(**inputs)
|
|
|
|
| 1294 |
>>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
|
| 1295 |
>>> import torch
|
| 1296 |
|
| 1297 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
|
| 1298 |
+
>>> model = RetrievaBertForCausalLM.from_pretrained("retrieva-jp/bert-1.3b", is_decoder=True)
|
| 1299 |
|
| 1300 |
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 1301 |
>>> outputs = model(**inputs)
|
|
|
|
| 1528 |
>>> from models import RetrievaBertForNextSentencePrediction
|
| 1529 |
>>> import torch
|
| 1530 |
|
| 1531 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
|
| 1532 |
+
>>> model = RetrievaBertForNextSentencePrediction.from_pretrained("retrieva-jp/bert-1.3b")
|
| 1533 |
|
| 1534 |
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 1535 |
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|