import os

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertModel, BertTokenizer
|
|
|
|
def index_to_onehot(indices, length):
    # Convert a list of positions into a multi-hot vector of the given length.
    return [1 if i in indices else 0 for i in range(length)]
|
|
|
|
def get_punctuation_position(tokenized_text, tokenizer):
    count = 0
    comma_pos = []
    period_pos = []
    punctuation_removed_text = []
    comma_id = tokenizer.convert_tokens_to_ids("、")
    period_id = tokenizer.convert_tokens_to_ids("。")

    # Record, for each punctuation mark, the index of the character it follows
    # in the punctuation-removed sequence: a mark at position i follows the
    # character at i - 1, shifted left by the `count` marks already removed.
    for i, c in enumerate(tokenized_text.tolist()):
        if c == comma_id:
            comma_pos.append(i - count - 1)
            count += 1
        elif c == period_id:
            period_pos.append(i - count - 1)
            count += 1
        else:
            punctuation_removed_text.append(c)

    # Pad the shortened sequence back to the fixed length of 512.
    if len(punctuation_removed_text) < 512:
        punctuation_removed_text += [tokenizer.pad_token_id] * (
            512 - len(punctuation_removed_text)
        )

    return (
        torch.tensor(punctuation_removed_text),
        [
            index_to_onehot(comma_pos, 512),
            index_to_onehot(period_pos, 512),
        ],
    )
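# Illustrative example (not from the original code): for the tokenized input
# [CLS] 今 日 、 晴 れ 。 [SEP], the comma target fires at the index of 日 and
# the period target at the index of れ in the punctuation-removed sequence,
# i.e. each target marks the character the mark should directly follow.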
|
|
|
|
class PunctuationPositionDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        # Space-separate the characters so the character-level vocabulary is
        # applied one character per token, without MeCab pre-tokenization.
        text = " ".join(list(text))
        inputs = self.tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids, label = get_punctuation_position(
            inputs["input_ids"][0], self.tokenizer
        )

        # Rebuild the attention mask from the punctuation-removed ids; the
        # tokenizer's own mask still counts the removed punctuation tokens.
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        # (2, 512) -> (512, 2): each position holds a [comma, period] target.
        label = torch.tensor(label, dtype=torch.float32).transpose(0, 1)

        return (input_ids, attention_mask, label, text)
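# Example item (illustrative): dataset[0] returns a (512,) LongTensor of
# punctuation-free ids, a (512,) attention mask, a (512, 2) float target,
# and the spaced source string.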
|
|
|
|
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
tokenizer = BertTokenizer.from_pretrained(model_name)
base_model = BertModel.from_pretrained(model_name)
|
|
|
|
class PunctuationPredictor(torch.nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(0.2)
        # Two logits per position: one for "comma follows", one for "period follows".
        self.linear = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # (batch, 512, 768) -> (batch, 512, 2)
        return self.linear(self.dropout(last_hidden_state))
|
|
|
|
model = PunctuationPredictor(base_model)


with open("data/train.txt", "r", encoding="utf-8") as f:
    # One sentence per line; drop trailing newlines and skip empty lines.
    texts = [line.strip() for line in f if line.strip()]
|
|
dataset = PunctuationPositionDataset(texts, tokenizer)

data_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)
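# With batch_size=16 the loader yields input_ids and attention masks of shape
# (16, 512) and targets of shape (16, 512, 2), matching the classifier head.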
|
|
optimizer = AdamW(
    [
        {"params": model.base_model.parameters(), "lr": 5e-5},
        {"params": model.linear.parameters(), "lr": 1e-3},
    ],
)
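# Discriminative learning rates: the pretrained encoder fine-tunes gently at
# 5e-5 while the randomly initialized linear head learns faster at 1e-3.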
|
|
criterion = torch.nn.BCEWithLogitsLoss()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()
model.to(device)
for epoch in range(10):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}")
    for step, batch in enumerate(progress_bar, start=1):
        input_ids, attention_masks, labels, text = batch
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_masks)
        # Independent sigmoid per channel: each position makes two binary
        # decisions, "comma follows here" and "period follows here".
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        # Running average over the batches seen so far in this epoch.
        progress_bar.set_postfix({"loss": epoch_loss / step})

# Ensure the output directory exists before saving.
os.makedirs("weight", exist_ok=True)
torch.save(model.state_dict(), "weight/punctuation_position_model.pth")
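
# --- Inference sketch -------------------------------------------------------
# Illustrative addition, not part of the original script: restores punctuation
# in a raw (punctuation-free) string. The 0.5 threshold and the comma-first
# tie-break are assumptions, not tuned choices.
def insert_punctuation(text, model, tokenizer, device="cuda", threshold=0.5):
    model.eval()
    spaced = " ".join(list(text))  # same character spacing as in training
    inputs = tokenizer(
        spaced,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
    probs = torch.sigmoid(logits)[0]  # (512, 2): [comma, period] per position
    out = []
    # Token position 0 is [CLS], so character i of `text` sits at position i + 1;
    # 510 characters fit alongside [CLS] and [SEP] in the 512-token window.
    for pos, ch in enumerate(text[:510], start=1):
        out.append(ch)
        if probs[pos, 0] > threshold:
            out.append("、")
        elif probs[pos, 1] > threshold:
            out.append("。")
    return "".join(out)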
|
|