Spaces:
Sleeping
Sleeping
File size: 1,014 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 | import torch
from koja_diffuser.util import file_to_tensor
# Bundle the stage-1 Korean/Japanese checkpoints, the stage-2 bridge checkpoint
# and the two token parquet files into a single ./dist/full.pt.
# All paths are relative to the current working directory, so run this script
# from the directory that contains ./dist.

# Input/output locations. The numeric file names select which saved checkpoint
# of each stage gets bundled.
KO_TOKENS_PATH = "./dist/ko_token.parquet"
JA_TOKENS_PATH = "./dist/ja_token.parquet"
KO_STAGE1_CKPT = "./dist/stage1_ko/1000.pt"
JA_STAGE1_CKPT = "./dist/stage1_ja/1000.pt"
STAGE2_CKPT = "./dist/stage2/300.pt"
OUTPUT_PATH = "./dist/full.pt"

# Kept so anything importing ``device`` from this module keeps working; the
# checkpoints below are deliberately NOT loaded onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Always deserialize onto the CPU. torch.save records each tensor's device, so
# mapping to CUDA here would produce a bundle that fails to load on CPU-only
# machines (unless the loader passes map_location), and would copy every weight
# to the GPU only to write it straight back to disk.
_LOAD_DEVICE = "cpu"


def _language_entry(parquet, ckpt):
    """Build the per-language section of the bundle.

    Args:
        parquet: whatever ``file_to_tensor`` returned for the language's token
            parquet file (presumably a tensor, per the helper's name — confirm).
        ckpt: a stage-1 checkpoint dict; must contain the keys ``"encoder"``,
            ``"decoder"`` and ``"config"``.

    Returns:
        A dict with keys ``parquet``, ``encoder``, ``decoder``, ``config`` (in
        that order), the values taken unchanged from the arguments.

    Raises:
        KeyError: if ``ckpt`` lacks one of the required keys.
    """
    return {
        "parquet": parquet,
        "encoder": ckpt["encoder"],
        "decoder": ckpt["decoder"],
        "config": ckpt["config"],
    }


ko_parquet = file_to_tensor(KO_TOKENS_PATH)
ja_parquet = file_to_tensor(JA_TOKENS_PATH)
ko_ckpt = torch.load(KO_STAGE1_CKPT, map_location=_LOAD_DEVICE)
ja_ckpt = torch.load(JA_STAGE1_CKPT, map_location=_LOAD_DEVICE)
stage2_ckpt = torch.load(STAGE2_CKPT, map_location=_LOAD_DEVICE)

torch.save(
    {
        # Stage-2 entries copied verbatim from the stage-2 checkpoint
        # ("kj"/"jk" presumably mean ko->ja and ja->ko — confirm).
        "bridge_kj": stage2_ckpt["bridge_kj"],
        "bridge_jk": stage2_ckpt["bridge_jk"],
        "config": stage2_ckpt["config"],
        "ko": _language_entry(ko_parquet, ko_ckpt),
        "ja": _language_entry(ja_parquet, ja_ckpt),
    },
    OUTPUT_PATH,
)
|