"""
Multilingual knowledge-distillation training script (teacher -> student)
using sentence-transformers.

Created on Sat Jun 17 16:20:22 2023

@author: fujidai
"""
| |
|
| |
|
| | from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from sentence_transformers.datasets import ParallelSentencesDataset |
| | from datetime import datetime |
| |
|
| | import os |
| | import logging |
| | import sentence_transformers.util |
| | import csv |
| | import gzip |
| | from tqdm.autonotebook import tqdm |
| | import numpy as np |
| | import zipfile |
| | import io |
| |
|
# Route log records through sentence-transformers' LoggingHandler so log
# lines interleave cleanly with tqdm progress bars on the console.
log_format = '%(asctime)s - %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
# --- Model checkpoints --------------------------------------------------
# Placeholder paths (Japanese): replace with the teacher / student
# checkpoints produced by the corresponding fine-tuning scripts.
teacher_model_name = 'TED-finetuning_teacher.py で作成した教師モデル'
student_model_name = 'TED-finetuning_student.py で作成した生徒モデル'

# --- Data / batching ----------------------------------------------------
max_seq_length = 128                 # student tokenizer truncation length
train_batch_size = 64                # mini-batch size during distillation
inference_batch_size = 64            # batch size for embedding inference
max_sentences_per_language = 500000  # presumably a per-language cap on loaded pairs
train_max_sentence_length = 250      # presumably max sentence length kept for training

# --- Optimization schedule ----------------------------------------------
num_epochs = 100
num_warmup_steps = 10000

# --- Evaluation cadence -------------------------------------------------
num_evaluation_steps = 1000          # evaluate every N training steps
dev_sentences = 1000                 # dev-set size (unused in visible code)
| |
|
| |
|
| | |
# ---------------------------------------------------------------------------
# Load the teacher model and build a fresh student model.
#
# Fix: the original hard-coded device='mps', which raises at runtime on any
# machine without Apple-Silicon Metal support.  Select 'mps' only when it is
# actually available and fall back to CPU otherwise.
# ---------------------------------------------------------------------------
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name, device=device)

logger.info("Create student model from scratch")
# Transformer encoder truncated at max_seq_length, followed by a Pooling
# layer (library default pooling mode) to produce fixed-size sentence vectors.
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

print(teacher_model)
print(student_model)
| |
|
| |
|
# ---------------------------------------------------------------------------
# Build the parallel-sentence distillation dataset, dataloader, and loss.
#
# Fixes relative to the original:
#   * dropped the duplicate `from sentence_transformers.datasets import
#     ParallelSentencesDataset` (already imported at the top of this file);
#   * `max_sentences_per_language` and `train_max_sentence_length` were
#     defined but never used — pass them to load_data() so they take effect.
# ---------------------------------------------------------------------------
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
# NOTE(review): '/en-other.txt' is an absolute path at filesystem root —
# confirm this is really where the tab-separated parallel corpus lives.
train_data.load_data('/en-other.txt',
                     max_sentences=max_sentences_per_language,
                     max_sentence_length=train_max_sentence_length)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# Student embeddings are regressed onto the teacher's via mean-squared error.
train_loss = losses.MSELoss(model=student_model)

print(train_data)
| |
|
| |
|
| | |
| |
|
| | |
# Marker print kept from the original script (signals training is starting).
print('az')

# Distillation run: minimize the MSE between student and teacher embeddings.
adam_params = {'lr': 2e-5, 'eps': 1e-6}
training_objectives = [(train_dataloader, train_loss)]
student_model.fit(
    train_objectives=training_objectives,
    epochs=num_epochs,
    warmup_steps=num_warmup_steps,
    evaluation_steps=num_evaluation_steps,
    optimizer_params=adam_params,
    checkpoint_path='checkpoint-savename',
    checkpoint_save_steps=2000,
)

# Persist the distilled student model to disk.
student_model.save('savename')
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| |
|