"""
Created on Sun Aug 13 20:57:28 2023

@author: fujidai
"""
|
|
|
|
| import torch |
| from sentence_transformers import SentenceTransformer, InputExample, losses,models |
| from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses |
| from sentence_transformers.readers import InputExample |
| from torch.utils.data import DataLoader |
| from transformers import AutoTokenizer |
| from sentence_transformers.SentenceTransformer import SentenceTransformer |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from sentence_transformers import SentenceTransformer, util |
|
|
|
|
| |
# Build a SentenceTransformer from a local MPNet checkpoint:
# a transformer encoder followed by a mean-pooling layer.
word_embedding_model = models.Transformer('/paraphrase-mpnet-base-v2', max_seq_length=512)

# Pool token embeddings down to one fixed-size sentence embedding.
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(embedding_dim)

# 'mps' targets the Apple-silicon GPU backend.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='mps')
print(model)
|
|
|
|
def _read_float_lines(path):
    """Read *path* and return one float per line.

    The file is expected to hold one numeric label per line
    (normalized WMT direct-assessment scores); a non-numeric or
    empty line raises ValueError, same as the original code.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [float(line) for line in f.read().splitlines()]


# Gold similarity labels, one float per training triple.
data = _read_float_lines('/WMT_da_学習データ_88993文/en-label-正規化.txt')
|
|
|
|
|
|
|
|
def _read_lines(path):
    """Read *path* (UTF-8) and return its lines as a list of strings."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().splitlines()


# Source sentences, pseudo (machine-translated) sentences, and
# pseudo-pseudo (round-trip) sentences, aligned line-by-line.
left_lines = _read_lines('/WMT_da_学習データ_88993文/en-origin.txt')
senter_lines = _read_lines('/WMT_da_学習データ_88993文/en-pseudo.txt')
right_lines = _read_lines('/WMT_da_学習データ_88993文/en-pseudo-pseudo.txt')
|
|
|
|
# Assemble one InputExample per aligned (origin, pseudo, pseudo-pseudo)
# triple, labelled with its DA score. zip() stops at the shortest input,
# where the original index loop would have raised IndexError on a
# shorter labels file — the files are presumed equal length.
train_examples = [
    InputExample(texts=[origin, pseudo, pseudo_pseudo], label=label)
    for origin, pseudo, pseudo_pseudo, label in zip(
        left_lines, senter_lines, right_lines, data
    )
]
print(len(train_examples))
|
|
|
|
# Apple-silicon GPU backend. NOTE(review): `device` is not referenced
# below — the model was already placed on 'mps' at construction.
device = torch.device('mps')

# (Removed a duplicate `import torch.nn.functional as F` — F is already
# imported at the top of the file and is unused here.)

# Batches of InputExamples; SentenceTransformer.fit() handles collation.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)

# Regresses the cosine similarity of the two sentence embeddings
# toward the gold label.
train_loss = losses.CosineSimilarityLoss(model)
|
|
|
|
| |
# Fine-tune the model: 100 epochs with linear warmup, a checkpoint
# every 11125 steps, then persist the final weights.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=100,
    warmup_steps=100,
    show_progress_bar=True,
    checkpoint_path='checkpoint_save_name',
    checkpoint_save_steps=11125,
    save_best_model=True,
)
model.save("last_save_name")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|