import torch
import json
from typing import Literal
from pydantic import BaseModel

# Accepted string aliases for each torch dtype.
FP32_STR = ["float32", "fp32"]
FP16_STR = ["float16", "fp16", "half"]
BF16_STR = ["bfloat16", "bf16"]


def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Resolve a dtype alias (case-insensitive) to the corresponding torch.dtype."""
    dtype_str = dtype_str.lower()
    if dtype_str in FP32_STR:
        return torch.float32
    elif dtype_str in FP16_STR:
        return torch.float16
    elif dtype_str in BF16_STR:
        return torch.bfloat16
    else:
        raise ValueError(f"Unsupported dtype string: {dtype_str}")


class DenoiserConfig(BaseModel):
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    hidden_size: int = 1024
    depth: int = 24
    num_heads: int = 16
    mlp_ratio: float = 4.0
    attn_dropout: float = 0.0
    proj_dropout: float = 0.0
    bottleneck_dim: int = 128
    num_time_tokens: int = 4
    rope_theta: float = 256.0
    rope_axes_dims: list[int] = [16, 24, 24]
    rope_axes_lens: list[int] = [256, 128, 128]
    rope_zero_centered: list[bool] = [False, True, True]
    context_dim: int  # no default; must be provided


class JiT_B_16_Config(DenoiserConfig):
    patch_size: int = 16
    depth: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    bottleneck_dim: int = 128
    context_dim: int = 768
    rope_axes_dims: list[int] = [16, 24, 24]  # sum = 64 = 768 / 12 (head dim)
    rope_axes_lens: list[int] = [
        256,  # max 256 token text
        128,  # 2048x2048 image size (2048 / patch_size 16 = 128 patches per axis)
        128,
    ]


ContextType = Literal["class", "text"]


class ClassContextConfig(BaseModel):
    type: Literal["class"] = "class"
    label2id_map_path: str

    @property
    def label2id(self) -> dict[str, int]:
        """Load the class-label-to-id mapping from the JSON file at label2id_map_path."""
        with open(self.label2id_map_path, "r") as f:
            label2id = json.load(f)
        return label2id


class TextContextConfig(BaseModel):
    type: Literal["text"] = "text"
    pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"


# The `type` literal on each variant lets pydantic select the right branch.
ContextConfig = ClassContextConfig | TextContextConfig


class JiTConfig(BaseModel):
    dtype: str = "float32"
    context_encoder: ContextConfig
    denoiser: DenoiserConfig = JiT_B_16_Config()

    @property
    def torch_dtype(self) -> torch.dtype:
        return str_to_dtype(self.dtype)
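

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal example of constructing, resolving, and round-tripping a config.
# It assumes pydantic v2 (`model_dump` / `model_validate`); the values shown
# are arbitrary examples, not recommended settings.
if __name__ == "__main__":
    config = JiTConfig(
        dtype="bf16",  # any alias in BF16_STR resolves to torch.bfloat16
        context_encoder=TextContextConfig(),  # or ClassContextConfig(label2id_map_path=...)
    )
    assert config.torch_dtype == torch.bfloat16

    # Round-trip through a plain dict (e.g. for JSON serialization). Note that
    # `denoiser` is annotated as DenoiserConfig, so the JiT_B_16_Config default
    # is re-validated as the base class on the way back in: its field values
    # survive, but the subclass identity does not.
    restored = JiTConfig.model_validate(config.model_dump())
    assert restored.denoiser.hidden_size == 768
    assert isinstance(restored.context_encoder, TextContextConfig)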