"""
Created on Sat Jun 17 16:20:22 2023

@author: fujidai
"""
|
|
|
|
| from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses |
| import torch |
| from torch.utils.data import DataLoader |
| from sentence_transformers.datasets import ParallelSentencesDataset |
| from datetime import datetime |
|
|
| import os |
| import logging |
| import sentence_transformers.util |
| import csv |
| import gzip |
| from tqdm.autonotebook import tqdm |
| import numpy as np |
| import zipfile |
| import io |
|
|
# Route log records (including sentence-transformers' own messages) to
# stdout with a timestamp prefix.
log_format = '%(asctime)s - %(message)s'
logging.basicConfig(
    format=log_format,
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)
|
|
|
|
# --- Model paths ---------------------------------------------------------
# Teacher: the fine-tuned model produced by teacher_finetune.py
# (the string below is a placeholder path written in Japanese).
teacher_model_name = 'teacher_finetune.py で作成したモデル'
# Student: transformer checkpoint the new model is built from.
student_model_name = '完成2-MarginMSELoss-finetuning-6-30'

# --- Sequence / batching -------------------------------------------------
max_seq_length = 128                 # student truncation length (tokens)
train_batch_size = 64                # batch size during training
inference_batch_size = 64            # teacher-encoding batch size (not referenced below — TODO confirm intended use)
max_sentences_per_language = 500000  # per-language cap (not referenced below)
train_max_sentence_length = 250      # max sentence length filter (not referenced below)

# --- Training schedule ---------------------------------------------------
num_epochs = 100
num_warmup_steps = 10000             # linear LR warm-up steps

# --- Evaluation ----------------------------------------------------------
num_evaluation_steps = 1000          # steps between evaluations in fit()
dev_sentences = 1000                 # dev-set size (not referenced below)
|
|
|
|
# Pick the compute device once for both models.  The original hard-coded
# device='mps', which raises on any machine without Apple-silicon MPS;
# fall back to sentence-transformers' default device selection there.
device = 'mps' if torch.backends.mps.is_available() else None

logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name, device=device)

logger.info("Create student model from scratch")
# Student = transformer encoder + mean pooling over token embeddings,
# truncating inputs to max_seq_length tokens.
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

print(teacher_model)
print(student_model)
|
|
|
|
# NOTE(review): a redundant `from sentence_transformers.datasets import
# ParallelSentencesDataset` used to sit here; it duplicates the top-of-file
# import and was removed.

# Distillation data: the teacher pre-computes embeddings for the source
# sentences and the student is trained to reproduce them.
# (presumably the file holds tab-separated parallel sentences, the format
# ParallelSentencesDataset.load_data expects — verify against the data file)
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
train_data.load_data('output-100000-karanasi.txt')

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# MSE between student embeddings and the teacher's embeddings — the
# knowledge-distillation objective.
train_loss = losses.MSELoss(model=student_model)

print(train_data)
|
|
# Distil the teacher into the student.
# NOTE(review): no `evaluator` is passed, so `evaluation_steps` currently
# has no effect; it is kept for when an evaluator is added.
# (A stray debug `print('az')` that preceded this call was removed.)
student_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=num_warmup_steps,
    evaluation_steps=num_evaluation_steps,
    optimizer_params={'lr': 2e-5, 'eps': 1e-6},
    checkpoint_path='checkpoint_savename',
    checkpoint_save_steps=2344,  # presumably ≈ one epoch of optimizer steps — TODO confirm
)

# Persist the final distilled student model.
student_model.save('savename')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|