# Plat — init (commit 4b08319)
import torch
import json
from typing import Literal
from pydantic import BaseModel
# String aliases recognized for each supported torch floating-point dtype.
# Matching is done case-insensitively (inputs are lowercased before lookup).
FP32_STR = ["float32", "fp32"]
FP16_STR = ["float16", "fp16", "half"]
BF16_STR = ["bfloat16", "bf16"]
def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a dtype alias string (e.g. "fp16", "bfloat16") to a torch.dtype.

    Matching is case-insensitive.

    Raises:
        ValueError: if the string is not a recognized dtype alias.
    """
    alias_table: dict[str, torch.dtype] = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    key = dtype_str.lower()
    if key not in alias_table:
        raise ValueError(f"Unsupported dtype string: {key}")
    return alias_table[key]
class DenoiserConfig(BaseModel):
    """Hyperparameters for the denoiser transformer.

    NOTE(review): field names suggest a ViT-style patch transformer with a
    3-axis rotary position embedding — confirm against the model code.
    """

    patch_size: int = 16  # patch edge length, presumably in pixels
    in_channels: int = 3  # input image channels
    out_channels: int = 3  # output image channels
    hidden_size: int = 1024  # transformer embedding dimension
    depth: int = 24  # number of transformer layers
    num_heads: int = 16  # attention heads per layer
    mlp_ratio: float = 4.0  # MLP hidden dim as a multiple of hidden_size
    attn_dropout: float = 0.0  # dropout on attention weights
    proj_dropout: float = 0.0  # dropout on projection layers
    bottleneck_dim: int = 128
    num_time_tokens: int = 4  # tokens reserved for the timestep embedding
    rope_theta: float = 256.0  # RoPE base frequency
    # One entry per RoPE axis; in JiT_B_16_Config the dims sum to the per-head
    # dimension (hidden_size / num_heads) — presumably the same invariant holds here.
    rope_axes_dims: list[int] = [16, 24, 24]
    rope_axes_lens: list[int] = [256, 128, 128]  # max positions per axis
    rope_zero_centered: list[bool] = [False, True, True]
    context_dim: int  # dimension of the conditioning context (required, no default)
class JiT_B_16_Config(DenoiserConfig):
    """Base-size (ViT-B-like) denoiser preset with 16x16 patches.

    Overrides DenoiserConfig defaults: 12 layers, hidden size 768, 12 heads.
    """

    patch_size: int = 16
    depth: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    bottleneck_dim: int = 128
    context_dim: int = 768
    rope_axes_dims: list[int] = [16, 24, 24]  # sum = 64 = 768 / 12
    rope_axes_lens: list[int] = [
        256,  # max 256 token text
        128,  # 2048x2048 image size
        128,
    ]
# Discriminator values used by the context-config classes below.
ContextType = Literal["class", "text"]
class ClassContextConfig(BaseModel):
    """Context-encoder config for class-conditional generation.

    `label2id_map_path` points to a JSON file mapping label names to
    integer class ids.
    """

    type: Literal["class"] = "class"
    label2id_map_path: str

    @property
    def label2id(self) -> dict[str, int]:
        """Load and return the label-name -> class-id mapping.

        The file is re-read on every access (no caching). Propagates
        FileNotFoundError / json.JSONDecodeError for a missing or
        malformed file.
        """
        # Explicit encoding: JSON is UTF-8 by spec, and the platform-default
        # encoding used by bare open() is not guaranteed to be UTF-8.
        with open(self.label2id_map_path, "r", encoding="utf-8") as f:
            label2id: dict[str, int] = json.load(f)
        return label2id
class TextContextConfig(BaseModel):
    """Context-encoder config for text conditioning.

    `pretrained_model` is presumably a Hugging Face model id used to load the
    text encoder — verify against the encoder-loading code.
    """

    type: Literal["text"] = "text"
    pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"
# Union of the supported context-encoder configurations; each variant carries
# a distinct literal `type` field.
ContextConfig = ClassContextConfig | TextContextConfig
class JiTConfig(BaseModel):
    """Top-level model configuration: compute dtype, context encoder, denoiser."""

    dtype: str = "float32"  # any alias accepted by str_to_dtype (fp32/fp16/bf16 families)
    context_encoder: ContextConfig  # required: class- or text-conditioning config
    denoiser: DenoiserConfig = JiT_B_16_Config()  # defaults to the B/16 preset

    @property
    def torch_dtype(self) -> torch.dtype:
        # Raises ValueError if `dtype` is not a recognized alias.
        return str_to_dtype(self.dtype)