import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

from transformers import AutoTokenizer, BertForSequenceClassification, PreTrainedModel, PretrainedConfig, get_scheduler
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset
from accelerate import Accelerator
from tqdm import tqdm

from LUKE_pipe import generate
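
# Fine-tunes a lightweight rescoring classifier for extractive QA. LUKE_pipe.generate
# (a local module) is assumed, from its usage below, to return one dict per example
# with "prediction_text" (the beam's candidate answer strings) and "logits" (one
# span score per candidate). A BERT verifier scores each (question, candidate)
# pair, and ClassifierAdapter combines both signals into one score per beam entry.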

MAX_BEAM = 10  # beam width of the QA pipeline; doubles as the ignore label when the gold answer is missing

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

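# Scores each beam candidate: a BERT sequence classifier (fine-tuned as an
# answer-span verifier, judging by the checkpoint name) reads the
# (question, candidate) pair, and its two logits are concatenated with the
# candidate's LUKE span logit before a linear layer maps the three features
# to a single score.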
class ClassifierAdapter(nn.Module):
    def __init__(self, l1=3):
        # l1: features per candidate (1 LUKE span logit + 2 BERT verifier logits)
        super().__init__()
        self.linear1 = nn.Linear(l1, 1)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForSequenceClassification.from_pretrained("botcon/right_span_bert")
        self.relu = nn.ReLU()

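    # questions: list of strings, one per sample; answers: per-sample lists of
    # beam_size candidate strings; logits: matching per-sample lists of floats.
    # Returns a (samples, beam_size) tensor of candidate scores.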
    def forward(self, questions, answers, logits):
        beam_size = len(answers[0])
        samples = len(questions)
        # Repeat each question once per beam candidate so the flattened question
        # list lines up with the flattened, sample-major answer list (and with
        # the (samples, beam_size, 2) reshape below).
        questions = [question for question in questions for _ in range(beam_size)]
        answers = [answer for beam in answers for answer in beam]
        # Dynamic padding plus truncation avoids errors on pairs longer than
        # BERT's maximum sequence length.
        inputs = self.tokenizer(
            questions,
            answers,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        bert_logits = self.bert(**inputs).logits
        bert_logits = bert_logits.reshape(samples, beam_size, 2)
        logits = torch.FloatTensor(logits).to(device).unsqueeze(-1)
        logits = torch.cat((logits, bert_logits), dim=-1)
        logits = self.relu(logits)
        out = torch.squeeze(self.linear1(logits), dim=-1)
        return out

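# Thin PreTrainedModel wrapper so the adapter gets save_pretrained /
# from_pretrained / push_to_hub and returns Trainer-style outputs.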
class HuggingWrapper(PreTrainedModel):
    config_class = PretrainedConfig  # the config class itself, not an instance

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierAdapter()

    def forward(self, **kwargs):
        labels = kwargs.pop("labels")
        output = self.model(**kwargs)
        # Samples whose beam never contained the gold answer carry label
        # MAX_BEAM, which the loss ignores.
        loss_fn = CrossEntropyLoss(ignore_index=MAX_BEAM)
        loss = loss_fn(output, labels)
        return SequenceClassifierOutput(logits=output, loss=loss)

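# Training setup: fp16 via Accelerate; "botcon/special_bert" is presumably a
# previously saved HuggingWrapper checkpoint. AdamW is left at its defaults.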
accelerator = Accelerator(mixed_precision="fp16")
model = HuggingWrapper.from_pretrained("botcon/special_bert").to(device)
optimizer = AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

batch_size = 2
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
# Round up so the step count covers the final, possibly smaller, batch.
num_updates = (len(raw_train) + batch_size - 1) // batch_size
num_epoch = 2
num_training_steps = num_updates * num_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

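# One in-order pass over SQuAD per epoch, batch_size examples at a time
# (note: no shuffling between epochs).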
for epoch in range(num_epoch):
    start = 0
    end = batch_size
    steps = 0
    cumu_loss = 0
    model.train()
    while start < len(raw_train):
        optimizer.zero_grad()
        batch_data = raw_train.select(range(start, min(end, len(raw_train))))
        # Candidate generation runs the frozen LUKE pipeline; no gradients needed.
        with torch.no_grad():
            res = generate(batch_data)
        prediction = []
        predicted_logit = []
        labels = []
        for i in range(len(res)):
            x = res[i]
            ground_answer = batch_data["answers"][i]["text"][0]
            predicted_text = x["prediction_text"]
            # The target class is the beam index of the gold answer;
            # MAX_BEAM marks "not in beam" and is ignored by the loss.
            found = False
            for k in range(len(predicted_text)):
                if predicted_text[k] == ground_answer:
                    labels.append(k)
                    found = True
                    break
            if not found:
                labels.append(MAX_BEAM)
            prediction.append(predicted_text)
            predicted_logit.append(x["logits"])
        labels = torch.LongTensor(labels).to(device)
        classifier_out = model(questions=batch_data["question"], answers=prediction, logits=predicted_logit, labels=labels)
        loss = classifier_out.loss
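        # fp16 training can occasionally yield a NaN loss; skip the update in
        # that case so NaN gradients never reach the weights.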
        if not torch.isnan(loss).item():
            cumu_loss += loss.item()
            steps += 1
            accelerator.backward(loss)
            optimizer.step()
            # Report the running mean loss every 100 optimizer steps.
            if steps % 100 == 0:
                print("Cumu loss: {}".format(cumu_loss / 100))
                cumu_loss = 0
        lr_scheduler.step()
        progress_bar.update(1)
        start += batch_size
        end += batch_size
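# Publish the trained rescorer; the repo name here is a placeholder.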
model.push_to_hub("some_fake_bert")