File size: 1,014 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from pathlib import Path

import torch

from koja_diffuser.util import file_to_tensor

# Bundle the following into one file, dist/full.pt:
#   - the stage-1 per-language checkpoints (encoder, decoder, config);
#   - the stage-2 checkpoint's bridge_kj / bridge_jk entries and config;
#   - the tensors built from the *_token.parquet files.
#
# Everything is loaded onto the CPU on purpose:
#   - This script only repackages data, so there is no reason to occupy GPU
#     memory.
#   - torch.save records the device of every tensor it writes. Loading onto
#     CUDA whenever a GPU happened to be present made the contents of full.pt
#     depend on the packing machine.
#   - Such a file also fails to load on a CPU-only host unless the consumer
#     passes map_location.
# Consumers can still move the weights to their own device after loading.

DIST_DIR = Path("./dist")
# Checkpoint file numbers to bundle (presumably training steps):
# stage1_<lang>/1000.pt and stage2/300.pt.
STAGE1_CKPT = 1000
STAGE2_CKPT = 300
OUTPUT_PATH = DIST_DIR / "full.pt"


def _load_checkpoint(path: Path) -> dict:
    """Load the checkpoint at *path* with every tensor mapped to the CPU."""
    # NOTE(review): torch>=2.6 defaults to weights_only=True; if any "config"
    # entry is a custom object rather than plain data, this load may need
    # weights_only=False (acceptable only for trusted, locally produced files).
    return torch.load(path, map_location="cpu")


def _language_bundle(lang: str) -> dict:
    """Return the parquet tensor plus stage-1 encoder/decoder/config for *lang*.

    *lang* is the language tag used in the file names ("ko" or "ja").
    """
    # file_to_tensor is given a str in case it does not accept Path objects.
    parquet = file_to_tensor(str(DIST_DIR / f"{lang}_token.parquet"))
    ckpt = _load_checkpoint(DIST_DIR / f"stage1_{lang}" / f"{STAGE1_CKPT}.pt")
    return {
        "parquet": parquet,
        "encoder": ckpt["encoder"],
        "decoder": ckpt["decoder"],
        "config": ckpt["config"],
    }


stage2_ckpt = _load_checkpoint(DIST_DIR / "stage2" / f"{STAGE2_CKPT}.pt")

# The output schema (top-level keys and the per-language sub-dicts) is kept
# exactly as downstream loaders expect it.
torch.save(
    {
        "bridge_kj": stage2_ckpt["bridge_kj"],
        "bridge_jk": stage2_ckpt["bridge_jk"],
        "config": stage2_ckpt["config"],
        "ko": _language_bundle("ko"),
        "ja": _language_bundle("ja"),
    },
    OUTPUT_PATH,
)