Spaces:
Sleeping
Sleeping
File size: 4,094 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from koja_diffuser.runtime.model_loader import load_model
from koja_diffuser.runtime.schedule import DiffusionSchedule
from koja_diffuser.runtime.ddim import ddim_sample_bridge, DdimConfig
import torch
from koja_diffuser.util import Emitter, noop_emitter, tensor_norm
from typing import Literal
class Inference:
    """Bundle a loaded model with its diffusion schedule for name translation.

    The model returned by ``load_model`` is expected to expose ``ko`` and
    ``ja`` objects (each with ``encode`` and ``decode``), the bridges
    ``bridge_kj`` and ``bridge_jk``, and ``config.diffusion_timesteps``.
    The two public coroutines, ``ko_to_ja`` and ``ja_to_ko``, run the same
    encode -> DDIM bridge sampling -> decode pipeline in opposite directions
    (presumably Korean <-> Japanese, judging by the names).
    """

    def __init__(
        self,
        *,
        device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    ):
        """Load the model and build its diffusion schedule on ``device``.

        Args:
            device: Device passed to ``load_model``, used to place the
                schedule, and used for every generator built by
                ``make_generator``. The default is computed once, when the
                class body is evaluated, not on each instantiation.
        """
        self.model = load_model(device=device)
        self.schedule = DiffusionSchedule(
            timesteps=self.model.config.diffusion_timesteps
        ).to(device=device)
        self.device = device

    def resolve_seed(self, seed: int | None) -> int:
        """Return a concrete seed, drawing a random one when none is given.

        Args:
            seed: Requested seed. ``None`` or any negative value means
                "pick one for me".

        Returns:
            ``seed`` as an ``int`` when it is non-negative; otherwise a value
            drawn on the CPU with ``torch.randint`` from ``[0, 2**31 - 1)``
            (the upper bound is exclusive).
        """
        if seed is None or seed < 0:
            return int(
                torch.randint(
                    low=0,
                    high=2**31 - 1,
                    size=(1,),
                    device="cpu",
                ).item()
            )
        return int(seed)

    def make_generator(self, seed: int | None) -> tuple[torch.Generator, int]:
        """Create a generator on ``self.device`` seeded via ``resolve_seed``.

        Args:
            seed: Requested seed; see ``resolve_seed`` for how ``None`` and
                negative values are handled.

        Returns:
            A ``(generator, used_seed)`` pair, where ``used_seed`` is the
            seed that was actually applied to the generator.
        """
        used_seed = self.resolve_seed(seed)
        generator = torch.Generator(device=self.device)
        generator.manual_seed(used_seed)
        return generator, used_seed

    async def _translate(
        self,
        names: list[str],
        ages: list[int],
        *,
        direction: str,
        encoder,
        bridge,
        decoder,
        seed: int | None,
        num_steps: int,
        start_timestep: int,
        sampling_mode: Literal["greedy", "sample"],
        temperature: float,
        top_k: int,
        top_p: float,
        emit: Emitter,
    ):
        """Run the shared encode -> bridge sampling -> decode pipeline.

        Args:
            names: Input names, passed to ``encoder.encode``.
            ages: Ages, passed to ``encoder.encode`` alongside ``names``.
            direction: Label emitted as ``"type"`` in the ``"info"`` event.
            encoder: Object whose ``encode(names, ages)`` returns
                ``(guide, names_ids)``.
            bridge: Bridge model handed to ``ddim_sample_bridge``.
            decoder: Object whose ``decode(...)`` returns
                ``(decoded, decoded_ids)``.
            seed: Requested seed; resolved by ``make_generator``.
            num_steps: Second positional argument of ``DdimConfig``.
            start_timestep: First positional argument of ``DdimConfig``.
            sampling_mode: Forwarded to ``decoder.decode``.
            temperature: Forwarded to ``decoder.decode``.
            top_k: Forwarded to ``decoder.decode``.
            top_p: Forwarded to ``decoder.decode``.
            emit: Awaitable callback invoked as ``emit(event, payload)`` for
                the ``"info"``, ``"encoded"`` and ``"decoded"`` events; it is
                also passed through to ``ddim_sample_bridge``.

        Returns:
            The first element returned by ``decoder.decode``.
        """
        generator, seed = self.make_generator(seed)
        await emit("info", {"type": direction, "seed": seed})

        guide, names_ids = encoder.encode(names, ages)
        await emit(
            "encoded",
            {"guide": tensor_norm(guide), "names_ids": tensor_norm(names_ids)},
        )

        z_hat = await ddim_sample_bridge(
            bridge=bridge,
            generator=generator,
            schedule=self.schedule,
            guide=guide,
            config=DdimConfig(start_timestep, num_steps),
            emit=emit,
        )

        # Decoding gets a fresh generator re-seeded with the resolved seed
        # rather than the one already consumed by the bridge sampler, so the
        # decode-time random stream does not depend on how many draws the
        # sampler made.
        decoded, decoded_ids = decoder.decode(
            z_hat,
            sampling_mode=sampling_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            generator=self.make_generator(seed)[0],
        )
        await emit(
            "decoded", {"result": decoded, "names_ids": tensor_norm(decoded_ids)}
        )
        return decoded

    # NOTE(review): torch's context-manager decorators wrap the plain call,
    # which for an ``async def`` only creates the coroutine object; the awaited
    # body may therefore run outside inference mode. Confirm against the
    # installed torch version before relying on this decorator.
    @torch.inference_mode()
    async def ko_to_ja(
        self,
        names: list[str],
        ages: list[int],
        *,
        seed: int | None = None,
        num_steps: int = 6,
        start_timestep: int = 500,
        sampling_mode: Literal["greedy", "sample"] = "sample",
        temperature: float = 0.8,
        top_k: int = 20,
        top_p: float = 0.9,
        emit: Emitter = noop_emitter,
    ):
        """Encode with ``model.ko``, sample through ``bridge_kj``, decode with ``model.ja``.

        Args:
            names: Input names to encode.
            ages: Ages passed to the encoder together with ``names``.
            seed: Seed for sampling; ``None`` or a negative value draws a
                random one. The seed actually used is reported in the
                ``"info"`` event.
            num_steps: Number of DDIM steps given to ``DdimConfig``.
            start_timestep: Starting timestep given to ``DdimConfig``.
            sampling_mode: Decoder sampling mode, ``"greedy"`` or ``"sample"``.
            temperature: Forwarded to the decoder.
            top_k: Forwarded to the decoder.
            top_p: Forwarded to the decoder.
            emit: Awaitable progress callback; defaults to ``noop_emitter``.

        Returns:
            The decoded result produced by ``model.ja.decode``.
        """
        return await self._translate(
            names,
            ages,
            direction="ko_to_ja",
            encoder=self.model.ko,
            bridge=self.model.bridge_kj,
            decoder=self.model.ja,
            seed=seed,
            num_steps=num_steps,
            start_timestep=start_timestep,
            sampling_mode=sampling_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            emit=emit,
        )

    # NOTE(review): same caveat as on ``ko_to_ja`` -- this decorator may not
    # cover the awaited body of an ``async def``; confirm.
    @torch.inference_mode()
    async def ja_to_ko(
        self,
        names: list[str],
        ages: list[int],
        *,
        seed: int | None = None,
        num_steps: int = 6,
        start_timestep: int = 500,
        sampling_mode: Literal["greedy", "sample"] = "sample",
        temperature: float = 0.8,
        top_k: int = 20,
        top_p: float = 0.9,
        emit: Emitter = noop_emitter,
    ):
        """Encode with ``model.ja``, sample through ``bridge_jk``, decode with ``model.ko``.

        Args:
            names: Input names to encode.
            ages: Ages passed to the encoder together with ``names``.
            seed: Seed for sampling; ``None`` or a negative value draws a
                random one. The seed actually used is reported in the
                ``"info"`` event.
            num_steps: Number of DDIM steps given to ``DdimConfig``.
            start_timestep: Starting timestep given to ``DdimConfig``.
            sampling_mode: Decoder sampling mode, ``"greedy"`` or ``"sample"``.
            temperature: Forwarded to the decoder.
            top_k: Forwarded to the decoder.
            top_p: Forwarded to the decoder.
            emit: Awaitable progress callback; defaults to ``noop_emitter``.

        Returns:
            The decoded result produced by ``model.ko.decode``.
        """
        return await self._translate(
            names,
            ages,
            direction="ja_to_ko",
            encoder=self.model.ja,
            bridge=self.model.bridge_jk,
            decoder=self.model.ko,
            seed=seed,
            num_steps=num_steps,
            start_timestep=start_timestep,
            sampling_mode=sampling_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            emit=emit,
        )
|