# KoJaDiffuser / koja_diffuser/train/merge_model.py
# github-actions[bot]
# Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
# e0552b0
import torch
from koja_diffuser.util import file_to_tensor
# Merge the two stage-1 (per-language) checkpoints, the stage-2 bridge
# checkpoint and the tokenized parquet data into a single bundle saved at
# OUTPUT_PATH.

# Checkpoint snapshots to merge; change these to pick a different step.
STAGE1_KO_STEP = 1000
STAGE1_JA_STEP = 1000
STAGE2_STEP = 300
OUTPUT_PATH = "dist/full.pt"

# Merging only moves tensors between dicts, so no accelerator is needed.
# Load on the CPU rather than CUDA: torch.save records each tensor's device,
# so loading onto CUDA here would write CUDA tensors into the bundle. That
# would make it unloadable on CPU-only machines without an explicit
# map_location, and it would also waste GPU memory during the merge.
device = torch.device("cpu")

ko_parquet = file_to_tensor("./dist/ko_token.parquet")
ja_parquet = file_to_tensor("./dist/ja_token.parquet")

# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if any
# checkpoint's "config" is not a plain container these loads will fail there.
ko_ckpt = torch.load(
    f"./dist/stage1_ko/{STAGE1_KO_STEP}.pt", map_location=device
)
ja_ckpt = torch.load(
    f"./dist/stage1_ja/{STAGE1_JA_STEP}.pt", map_location=device
)
stage2_ckpt = torch.load(f"./dist/stage2/{STAGE2_STEP}.pt", map_location=device)


def _language_bundle(parquet, ckpt):
    """Build the per-language entry stored in the merged bundle.

    Args:
        parquet: The object returned by ``file_to_tensor`` for this
            language's token parquet file.
        ckpt: A stage-1 checkpoint dict. Its ``"encoder"``, ``"decoder"``
            and ``"config"`` entries are copied as-is.

    Returns:
        A dict with keys ``parquet``, ``encoder``, ``decoder`` and ``config``.

    Raises:
        KeyError: If ``ckpt`` lacks one of the required entries.
    """
    return {
        "parquet": parquet,
        "encoder": ckpt["encoder"],
        "decoder": ckpt["decoder"],
        "config": ckpt["config"],
    }


torch.save(
    {
        "bridge_kj": stage2_ckpt["bridge_kj"],
        "bridge_jk": stage2_ckpt["bridge_jk"],
        "config": stage2_ckpt["config"],
        "ko": _language_bundle(ko_parquet, ko_ckpt),
        "ja": _language_bundle(ja_parquet, ja_ckpt),
    },
    OUTPUT_PATH,
)