KoJaDiffuser / koja_diffuser /runtime /model_loader.py
github-actions[bot]
Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
e0552b0
import torch
from koja_diffuser.util import tensor_to_bytes
from koja_diffuser.model import Encoder, Decoder, DiffusionTranslate
from koja_diffuser.tokenizer.ko import KoreanTokenizer
from koja_diffuser.tokenizer.ja import JapaneseTokenizer
from koja_diffuser.config import Stage2Config, Config, get_config
from koja_diffuser.runtime.sample import sample_logits
from typing import Literal
from torch import Tensor
from pathlib import Path
from huggingface_hub import hf_hub_download
import dataclasses
def load_lang(lang: Literal["ko", "ja"], data, *, device: torch.device):
    """Build the tokenizer, encoder and decoder for a single language.

    Args:
        lang: Language code. ``"ko"`` selects ``KoreanTokenizer``; any other
            value selects ``JapaneseTokenizer``.
        data: Per-language checkpoint mapping. The entries read here are
            ``"config"``, ``"parquet"``, ``"encoder"`` and ``"decoder"``.
        device: Device both networks are moved to before their weights
            are loaded.

    Returns:
        A ``LangModels`` bundle holding the tokenizer, the two networks
        (both switched to eval mode), the device and the resolved config.
    """
    config = get_config(lang, data["config"])

    # The tokenizer is constructed from the bytes obtained by running the
    # checkpoint's "parquet" entry through ``tensor_to_bytes``.
    if lang == "ko":
        tokenizer_cls = KoreanTokenizer
    else:
        tokenizer_cls = JapaneseTokenizer
    tokenizer = tokenizer_cls(tensor_to_bytes(data["parquet"]))
    vocab_size = len(tokenizer)

    # Encoder: construct -> move to device -> load weights -> eval mode.
    encoder = Encoder(
        vocab_size=vocab_size,
        latent_size=config.latent_size,
        max_len=config.max_len,
    )
    encoder = encoder.to(device)
    encoder.load_state_dict(data["encoder"])
    encoder.eval()

    # Decoder: same preparation sequence as the encoder.
    decoder = Decoder(vocab_size=vocab_size, max_len=config.max_len)
    decoder = decoder.to(device)
    decoder.load_state_dict(data["decoder"])
    decoder.eval()

    return LangModels(tokenizer, encoder, decoder, device, config)
@dataclasses.dataclass(frozen=True)
class LangModels:
    """Immutable bundle of one language's tokenizer, encoder and decoder."""

    # Tokenizer used to turn names into token ids and back.
    tokenizer: KoreanTokenizer | JapaneseTokenizer
    # Network called as ``encoder(name_ids, ages)`` in ``encode``.
    encoder: Encoder
    # Network called as ``decoder(names)`` in ``decode``.
    decoder: Decoder
    # Device on which input tensors are created / moved.
    device: torch.device
    # Language configuration; ``max_len`` is read when tokenizing.
    config: Config

    def encode(self, names: list[str], ages: list[int]):
        """Tokenize ``names`` and run them through the encoder.

        Args:
            names: Raw name strings. Each one is tokenized with
                ``add_eos=True`` and ``max_len=self.config.max_len``.
            ages: One integer per name; values are clamped to ``[0, 9]``
                before being handed to the encoder.

        Returns:
            A tuple ``(encoder_output, token_ids)`` where ``token_ids`` is
            the list of per-name id lists produced by the tokenizer.
        """
        token_rows = [
            self.tokenizer.encode(name, add_eos=True, max_len=self.config.max_len)
            for name in names
        ]

        # Built on the CPU first, then transferred to the target device.
        name_ids = torch.tensor(token_rows, dtype=torch.long)
        name_ids = name_ids.to(device=self.device, non_blocking=True)

        age_ids = torch.tensor(ages, dtype=torch.long, device=self.device)
        age_ids = age_ids.clamp(min=0, max=9)

        return self.encoder(name_ids, age_ids), token_rows

    def decode(
        self,
        names: Tensor,
        *,
        sampling_mode: Literal["greedy", "sample"] = "sample",
        temperature: float = 0.8,
        top_k: int = 20,
        top_p: float = 0.9,
        generator: torch.Generator | None = None,
    ):
        """Run the decoder on ``names`` and convert the result to strings.

        Args:
            names: Tensor passed unchanged to the decoder (presumably the
                latent representation of a batch of names -- confirm
                against the callers).
            sampling_mode: ``"greedy"`` takes the argmax over the last
                dimension of the decoder output; any other value samples
                via ``sample_logits``.
            temperature: Forwarded to ``sample_logits``; unused when greedy.
            top_k: Forwarded to ``sample_logits``; unused when greedy.
            top_p: Forwarded to ``sample_logits``; unused when greedy.
            generator: Forwarded to ``sample_logits``; unused when greedy.

        Returns:
            A tuple ``(texts, token_ids)``: the decoded strings and the
            nested list of token ids they were decoded from.
        """
        decoder_out = self.decoder(names)
        if sampling_mode == "greedy":
            token_ids = decoder_out.argmax(dim=-1)
        else:
            token_ids = sample_logits(
                decoder_out, temperature, top_k, top_p, generator
            )

        decoded: list[list[int]] = token_ids.cpu().tolist()
        texts: list[str] = [self.tokenizer.decode(ids) for ids in decoded]
        return texts, decoded
@dataclasses.dataclass(frozen=True)
class Models:
    """Immutable bundle of everything loaded from one checkpoint.

    The field order is part of the interface: instances are constructed
    positionally.
    """

    # Korean tokenizer / encoder / decoder bundle.
    ko: LangModels
    # Japanese tokenizer / encoder / decoder bundle.
    ja: LangModels
    # Bridge built with the Korean latent size as source and the Japanese
    # latent size as target.
    bridge_kj: DiffusionTranslate
    # Bridge built with the Japanese latent size as source and the Korean
    # latent size as target.
    bridge_jk: DiffusionTranslate
    # Stage-2 configuration built from the checkpoint's top-level "config".
    config: Stage2Config
def load_model(*, path="./dist/full.pt", device: torch.device):
    """Load the full checkpoint and assemble every model it contains.

    Args:
        path: Location of the checkpoint file. When it is not an existing
            file, ``full.pt`` is fetched from the Hugging Face Hub instead.
        device: Device used as ``map_location`` for the checkpoint and as
            the target for all instantiated modules.

    Returns:
        A ``Models`` instance holding both language bundles, both bridge
        networks (in eval mode) and the stage-2 config.
    """
    if not Path(path).is_file():
        # Local file is missing: fall back to the published checkpoint.
        path = hf_hub_download(
            repo_id="MOKA-AYUMU/KoJaNameDiffuser",
            filename="full.pt",
            revision="main",
        )

    # NOTE(review): depending on the installed torch version's default for
    # ``weights_only``, torch.load may unpickle arbitrary objects, and the
    # fallback file comes from the unpinned "main" revision -- consider
    # pinning a revision and/or passing ``weights_only`` explicitly.
    ckpt = torch.load(path, map_location=device)

    ko = load_lang("ko", ckpt["ko"], device=device)
    ja = load_lang("ja", ckpt["ja"], device=device)

    def _load_bridge(source: str, target: str, weights_key: str):
        # Construct -> move to device -> load weights -> eval mode.
        bridge = DiffusionTranslate(
            source_latent_size=ckpt[source]["config"]["latent_size"],
            target_latent_size=ckpt[target]["config"]["latent_size"],
        )
        bridge = bridge.to(device=device)
        bridge.load_state_dict(ckpt[weights_key])
        bridge.eval()
        return bridge

    bridge_kj = _load_bridge("ko", "ja", "bridge_kj")
    bridge_jk = _load_bridge("ja", "ko", "bridge_jk")

    stage2_config = Stage2Config(**ckpt["config"])
    return Models(ko, ja, bridge_kj, bridge_jk, stage2_config)