F-Haru committed
Commit d447698 · 1 Parent(s): 837ef60

Delete distillation.py

Files changed (1)
  1. distillation.py +0 -106
distillation.py DELETED
@@ -1,106 +0,0 @@
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Sat Jun 17 16:20:22 2023
-
- @author: fujidai
- """
-
- from sentence_transformers import SentenceTransformer, LoggingHandler, models, losses
- from torch.utils.data import DataLoader
- from sentence_transformers.datasets import ParallelSentencesDataset
- import logging
-
- logging.basicConfig(format='%(asctime)s - %(message)s',
-                     datefmt='%Y-%m-%d %H:%M:%S',
-                     level=logging.INFO,
-                     handlers=[LoggingHandler()])
- logger = logging.getLogger(__name__)
-
- teacher_model_name = 'teacher model created by TED-finetuning_teacher.py'    # Our monolingual teacher model, which we want to extend to multiple languages
- student_model_name = 'student model created by TED-finetuning_student.py'    # Multilingual base model we use to imitate the teacher model
-
- max_seq_length = 128                  # Student model max. length for inputs (number of word pieces)
- train_batch_size = 64                 # Batch size for training
- inference_batch_size = 64             # Batch size at inference
- max_sentences_per_language = 500000   # Maximum number of parallel sentences for training
- train_max_sentence_length = 250       # Maximum length (characters) for parallel training sentences
-
- num_epochs = 100                      # Train for x epochs
- num_warmup_steps = 10000              # Warmup steps
-
- num_evaluation_steps = 1000           # Evaluate performance after every x steps
- dev_sentences = 1000                  # Number of parallel sentences used for development
-
- ######## Start the extension of the teacher model to multiple languages ########
- logger.info("Load teacher model")
- teacher_model = SentenceTransformer(teacher_model_name, device='mps')
-
- logger.info("Create student model from scratch")
- word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
- # Apply mean pooling to get one fixed-sized sentence vector
- # (a Dense module could be added here to change the output dimensionality, e.g. to 768)
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
- student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='mps')
-
- print(teacher_model)
- print(student_model)
-
- train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
- # Each line of the file is "English sentence <TAB> sentence in the other language"
- train_data.load_data('/en-other.txt')
-
- #train_data.load_data('/Users/fujidai/TED2020_data/data/tuikazumi/en-ja/TED2020.en-ja.en')
- train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
- # MSE loss: the student's embeddings are trained to match the teacher's embeddings
- train_loss = losses.MSELoss(model=student_model)
-
- print(train_data)
-
- #50000_all-MiniLM-L6-v2__paraphrase-distilroberta-base-v2_epoch-1
-
- # Train the student model
- student_model.fit(train_objectives=[(train_dataloader, train_loss)],
-                   epochs=num_epochs,
-                   warmup_steps=num_warmup_steps,
-                   evaluation_steps=num_evaluation_steps,
-                   optimizer_params={'lr': 2e-5, 'eps': 1e-6},
-                   checkpoint_path='checkpoint-savename',
-                   checkpoint_save_steps=2000  # adjust as needed
-                   )
-
- student_model.save('savename')
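
For reference, `ParallelSentencesDataset.load_data` expects a plain-text (optionally gzipped) file in which each line holds the English source sentence and its translation separated by a tab, as the comment next to `load_data` in the deleted script notes. A minimal sketch of preparing such a file follows; the sentence pairs are illustrative placeholders, and `en-other.txt` mirrors the filename used in the script:

```python
# Minimal sketch: build a tab-separated parallel file for ParallelSentencesDataset.
# The sentence pairs below are illustrative placeholders, not data from this repository.
pairs = [
    ("Hello, how are you?", "こんにちは、お元気ですか？"),
    ("Thank you very much.", "どうもありがとうございます。"),
]

with open("en-other.txt", "w", encoding="utf-8") as f:
    for en, other in pairs:
        f.write(f"{en}\t{other}\n")  # English <TAB> other language
```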
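
Because the student is trained to reproduce the teacher's embedding of the English sentence for both sides of each parallel pair, sentences in either language end up in one shared vector space. Once trained, the distilled student loads like any other SentenceTransformer model; a hedged usage sketch, assuming the `savename` output path from the script above:

```python
from sentence_transformers import SentenceTransformer

# Load the distilled student ('savename' is the placeholder path used in distillation.py).
model = SentenceTransformer('savename')

# English and Japanese sentences map into the same embedding space,
# so cross-lingual similarity comparisons work directly.
embeddings = model.encode(["Hello world", "こんにちは世界"])
print(embeddings.shape)  # (2, embedding_dimension)
```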