Spaces:
Sleeping
Sleeping
| import torch | |
| from koja_diffuser.util import tensor_to_bytes | |
| from koja_diffuser.model import Encoder, Decoder, DiffusionTranslate | |
| from koja_diffuser.tokenizer.ko import KoreanTokenizer | |
| from koja_diffuser.tokenizer.ja import JapaneseTokenizer | |
| from koja_diffuser.config import Stage2Config, Config, get_config | |
| from koja_diffuser.runtime.sample import sample_logits | |
| from typing import Literal | |
| from torch import Tensor | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| import dataclasses | |
def load_lang(lang: Literal["ko", "ja"], data, *, device: torch.device):
    """Rebuild the tokenizer, encoder and decoder for one language.

    Args:
        lang: ``"ko"`` selects ``KoreanTokenizer``; any other value selects
            ``JapaneseTokenizer``.
        data: Mapping for this language. The keys read here are
            ``"config"``, ``"parquet"``, ``"encoder"`` and ``"decoder"``.
        device: Device the encoder and decoder are moved to.

    Returns:
        A ``LangModels`` holding the tokenizer, both networks (in eval
        mode), the device and the resolved config.
    """
    config = get_config(lang, data["config"])

    # Pick the tokenizer class first, then feed it the serialized table.
    if lang == "ko":
        tokenizer_cls = KoreanTokenizer
    else:
        tokenizer_cls = JapaneseTokenizer
    tokenizer = tokenizer_cls(tensor_to_bytes(data["parquet"]))

    # Both networks are sized from the tokenizer vocabulary and share max_len.
    vocab_size = len(tokenizer)
    max_len = config.max_len

    encoder = Encoder(
        vocab_size=vocab_size,
        latent_size=config.latent_size,
        max_len=max_len,
    )
    encoder = encoder.to(device)
    encoder.load_state_dict(data["encoder"])
    encoder.eval()

    decoder = Decoder(vocab_size=vocab_size, max_len=max_len)
    decoder = decoder.to(device)
    decoder.load_state_dict(data["decoder"])
    decoder.eval()

    return LangModels(tokenizer, encoder, decoder, device, config)
| class LangModels: | |
| tokenizer: KoreanTokenizer | JapaneseTokenizer | |
| encoder: Encoder | |
| decoder: Decoder | |
| device: torch.device | |
| config: Config | |
| def encode(self, names: list[str], ages: list[int]): | |
| encoded_ids = [ | |
| self.tokenizer.encode( | |
| name, | |
| add_eos=True, | |
| max_len=self.config.max_len, | |
| ) | |
| for name in names | |
| ] | |
| names = torch.tensor( | |
| encoded_ids, | |
| dtype=torch.long, | |
| ).to(device=self.device, non_blocking=True) | |
| ages = torch.tensor(ages, dtype=torch.long, device=self.device).clamp( | |
| min=0, max=9 | |
| ) | |
| return self.encoder(names, ages), encoded_ids | |
| def decode( | |
| self, | |
| names: Tensor, | |
| *, | |
| sampling_mode: Literal["greedy", "sample"] = "sample", | |
| temperature: float = 0.8, | |
| top_k: int = 20, | |
| top_p: float = 0.9, | |
| generator: torch.Generator | None = None, | |
| ): | |
| if sampling_mode == "greedy": | |
| logits = self.decoder(names).argmax(dim=-1) | |
| else: | |
| logits = sample_logits( | |
| self.decoder(names), temperature, top_k, top_p, generator | |
| ) | |
| decoded: list[list[int]] = logits.cpu().tolist() | |
| res: list[str] = [] | |
| for ids in decoded: | |
| res.append(self.tokenizer.decode(ids)) | |
| return res, decoded | |
@dataclasses.dataclass
class Models:
    """Everything ``load_model`` returns: both languages plus the two bridges.

    Declared as a dataclass so that the positional construction
    ``Models(ko, ja, bridge_kj, bridge_jk, config)`` works; the field order
    below is therefore part of the interface.
    """

    # Korean tokenizer/encoder/decoder bundle.
    ko: LangModels
    # Japanese tokenizer/encoder/decoder bundle.
    ja: LangModels
    # Bridge built with the Korean latent size as source, Japanese as target.
    bridge_kj: DiffusionTranslate
    # Bridge built with the Japanese latent size as source, Korean as target.
    bridge_jk: DiffusionTranslate
    # Config built from the checkpoint's top-level "config" entry.
    config: Stage2Config
def load_model(*, path="./dist/full.pt", device: torch.device):
    """Load the full checkpoint and rebuild every model it contains.

    Args:
        path: Local checkpoint file. If it is not an existing file, the
            checkpoint is fetched with ``hf_hub_download`` instead.
        device: Device used as ``map_location`` and for all rebuilt modules.

    Returns:
        A ``Models`` instance with both language bundles, both bridges (in
        eval mode) and the ``Stage2Config`` built from ``ckpt["config"]``.
    """
    local_file_exists = Path(path).is_file()
    if not local_file_exists:
        path = hf_hub_download(
            repo_id="MOKA-AYUMU/KoJaNameDiffuser",
            filename="full.pt",
            revision="main",
        )

    # NOTE(review): torch.load deserializes via pickle; the file may come from
    # a remote repo pinned only to a branch name. Consider pinning a commit
    # revision and/or passing weights_only explicitly - confirm with owners.
    ckpt = torch.load(path, map_location=device)

    ko = load_lang("ko", ckpt["ko"], device=device)
    ja = load_lang("ja", ckpt["ja"], device=device)

    def _build_bridge(source_lang, target_lang, state_key):
        # Size the bridge from the two languages' stored latent sizes, then
        # restore its weights and switch it to eval mode.
        bridge = DiffusionTranslate(
            source_latent_size=ckpt[source_lang]["config"]["latent_size"],
            target_latent_size=ckpt[target_lang]["config"]["latent_size"],
        ).to(device=device)
        bridge.load_state_dict(ckpt[state_key])
        bridge.eval()
        return bridge

    bridge_kj = _build_bridge("ko", "ja", "bridge_kj")
    bridge_jk = _build_bridge("ja", "ko", "bridge_jk")

    stage2_config = Stage2Config(**ckpt["config"])
    return Models(ko, ja, bridge_kj, bridge_jk, stage2_config)