File size: 1,385 Bytes
82c0c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os

from transformers import PreTrainedModel
from transformers import AutoModelForMaskedLM, AutoTokenizer
from pytorch_lightning.loggers import WandbLogger

from src.regression.PL import FullModelPL, EncoderPL, DecoderPL
from src.regression.HF.configs import FullModelConfigHF

from config import DEVICE


class FullModelHF(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the Lightning full model.

    Loads a tokenizer and a masked-LM DistilBERT backbone from the checkpoints
    named in *config*, downloads the decoder checkpoint from a Weights & Biases
    artifact, and composes everything into a ``FullModelPL``.
    """

    config_class = FullModelConfigHF

    def __init__(self, config):
        """Build the composed model from a ``FullModelConfigHF``.

        Args:
            config: must provide ``tokenizer_ckpt``, ``bert_ckpt``,
                ``decoder_ckpt`` (a W&B artifact reference), ``layer_norm``
                and ``nontext_features``.

        Note:
            Performs network I/O at construction time — it pulls checkpoints
            from the HF hub and downloads a W&B artifact.
        """
        super().__init__(config)

        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_ckpt)
        # Keep only the DistilBERT encoder; the MLM head is discarded.
        mlm_bert = AutoModelForMaskedLM.from_pretrained(config.bert_ckpt)
        self.bert = mlm_bert.distilbert

        encoder = EncoderPL(tokenizer=self.tokenizer, bert=self.bert).to(DEVICE)

        # A logger instance is created only so that the W&B artifact can be
        # resolved/downloaded via use_artifact().
        wandb_logger = WandbLogger(
            project="transformers",
            entity="sanjin_juric_fot",
        )

        artifact = wandb_logger.use_artifact(config.decoder_ckpt)
        artifact_dir = artifact.download()
        # os.path.join instead of manual "/" string concatenation.
        decoder = DecoderPL.load_from_checkpoint(
            os.path.join(artifact_dir, "model.ckpt")
        ).to(DEVICE)

        self.model = FullModelPL(
            encoder=encoder,
            decoder=decoder,
            layer_norm=config.layer_norm,
            nontext_features=config.nontext_features,
        ).to(DEVICE)

    def forward(self, input):
        """Delegate to the wrapped Lightning model's loss computation.

        NOTE(review): this returns ``FullModelPL._get_loss(input)`` — i.e.
        a loss, not logits/hidden states as standard HF ``forward`` methods
        do. Callers relying on the usual HF output contract should verify.
        """
        return self.model._get_loss(input)