| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| from pathlib import Path |
| import json |
| import tempfile |
|
|
| from transformers import MT5Tokenizer, MT5TokenizerFast, MT5Config, MT5ForConditionalGeneration |
| from transformers.models.t5.tokenization_t5 import VOCAB_FILES_NAMES |
|
|
# Checkpoint whose tokenizer and config serve as the starting point, and the
# name of the tiny model this script produces (also used as the directory
# the results are saved into).
mname_from = "google/mt5-small"
mname_very_small = "mt5-tiny-random"

# Load the full-size config and tokenizer; both get shrunk further down.
config = MT5Config.from_pretrained(mname_from)
tokenizer = MT5Tokenizer.from_pretrained(mname_from)
| |
|
|
| |
import sys

# Protobuf bindings used to edit the serialized sentencepiece model.
# Prefer the copy that ships inside the installed `sentencepiece` package so
# the script does not depend on the current working directory. Fall back to
# a local checkout of the sentencepiece repo (path relative to the cwd) when
# the installed package does not provide the module.
try:
    from sentencepiece import sentencepiece_model_pb2 as model
except ImportError:
    sys.path.append("./sentencepiece/python/src/sentencepiece")
    import sentencepiece_model_pb2 as model
|
|
# Scratch directory for the intermediate sentencepiece files. The temp root
# is resolved through `tempfile` instead of hard-coding "/tmp", so the script
# also runs on platforms without that directory; on a default Linux setup
# the result is still "/tmp/mt5-small". Kept as a plain string because later
# code builds file paths from it with "+".
tmp_dir = str(Path(tempfile.gettempdir()) / "mt5-small")
tokenizer.save_pretrained(tmp_dir)

# Read back the raw bytes of the sentencepiece model saved just above.
file = tmp_dir + "/spiece.model"
data = Path(file).read_bytes()

# Parse the bytes into a ModelProto so its piece list can be edited.
m = model.ModelProto()
m.ParseFromString(data)
|
|
# Number of sentencepiece pieces to retain in the shrunken vocabulary.
keep_items = 5000

print("Shrinking vocab")
print(f"original dict {len(m.pieces)}")
# Drop pieces from the end of the list until only `keep_items` remain.
# A non-positive surplus makes the loop a no-op.
surplus = len(m.pieces) - keep_items
for _ in range(surplus):
    m.pieces.pop()
print(f"new dict {len(m.pieces)}")

# Persist the trimmed model next to the original one ...
short_model_path = tmp_dir + "/spiece-short.model"
with open(short_model_path, 'wb') as f:
    f.write(m.SerializeToString())

# ... and rebuild the tokenizer on top of the trimmed vocabulary file.
tokenizer = MT5Tokenizer(vocab_file=short_model_path)
|
|
# Overrides that shrink the architecture inherited from `mname_from`.
# NOTE(review): the origin of the "+ 12" in vocab_size is not evident from
# this script; the embeddings are resized to len(tokenizer) after the model
# is built, so confirm whether the offset still matters.
tiny_overrides = {
    "vocab_size": keep_items + 12,
    "d_model": 64,
    "d_ff": 256,
    "d_kv": 8,
    "num_layers": 8,
    "num_decoder_layers": 8,
    "num_heads": 4,
    "relative_attention_num_buckets": 32,
}
config.update(tiny_overrides)
print("new config", config)
|
|
# Instantiate a randomly initialised model from the shrunken config.
very_small_model = MT5ForConditionalGeneration(config)

# Resize the embeddings to the rebuilt tokenizer *before* reporting the size,
# so the printed parameter count describes the model that actually gets
# saved rather than the pre-resize one.
very_small_model.resize_token_embeddings(len(tokenizer))
print(f"num of params {very_small_model.num_parameters()}")
|
|
| |
# Smoke test: push a tiny source/target batch through the model once to make
# sure the shrunken tokenizer and model work together.
src_texts = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_texts = ["Summary of the text.", "Another summary."]

# NOTE(review): newer transformers releases deprecate prepare_seq2seq_batch;
# confirm against the transformers version this script is meant to run with.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=src_texts,
    tgt_texts=tgt_texts,
    return_tensors="pt",
)
outputs = very_small_model(**batch)

# Report the length of the logits for the first example in the batch.
first_example_logits = outputs.logits[0]
print("test output:", len(first_example_logits))
|
|
| |
# Convert the weights to half precision, then write the model, config and
# tokenizer (in that order) into the output directory.
very_small_model.half()
for artifact in (very_small_model, config, tokenizer):
    artifact.save_pretrained(mname_very_small)

print(f"Generated {mname_very_small}")
|
|
| |
| |
| |
|
|