File size: 2,208 Bytes
4b08319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
from typing import Literal

import torch
from pydantic import BaseModel, field_validator


# Accepted lower-case aliases for each supported floating-point dtype.
FP32_STR = ["float32", "fp32"]
FP16_STR = ["float16", "fp16", "half"]
BF16_STR = ["bfloat16", "bf16"]


def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a human-readable dtype name into the matching ``torch.dtype``.

    Matching is case-insensitive and ignores surrounding whitespace, so values
    read from config files such as ``" FP16 "`` or ``"BFloat16"`` are accepted.

    Args:
        dtype_str: One of the aliases in ``FP32_STR``, ``FP16_STR`` or
            ``BF16_STR`` (any letter case, optional surrounding whitespace).

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16``.

    Raises:
        ValueError: If ``dtype_str`` is not a recognized alias. The message
            shows the value as given plus the list of accepted aliases.
    """
    normalized = dtype_str.strip().lower()
    if normalized in FP32_STR:
        return torch.float32
    if normalized in FP16_STR:
        return torch.float16
    if normalized in BF16_STR:
        return torch.bfloat16
    supported = ", ".join(FP32_STR + FP16_STR + BF16_STR)
    raise ValueError(
        f"Unsupported dtype string: {dtype_str!r} (supported: {supported})"
    )


class DenoiserConfig(BaseModel):
    """Architecture hyper-parameters for the denoiser network.

    ``context_dim`` has no default and must always be supplied; every other
    field has a default (a 1024-wide, 24-layer, 16-head configuration).
    """

    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    hidden_size: int = 1024
    depth: int = 24
    num_heads: int = 16
    mlp_ratio: float = 4.0
    attn_dropout: float = 0.0
    proj_dropout: float = 0.0

    # NOTE(review): presumably the width of a patch-embedding bottleneck — confirm in the model code.
    bottleneck_dim: int = 128
    # NOTE(review): presumably the number of tokens used to embed the timestep — confirm.
    num_time_tokens: int = 4

    # RoPE settings, one list entry per position axis (the three lists
    # presumably must have equal length — confirm in the model code).
    # With the defaults, sum(rope_axes_dims) = 16 + 24 + 24 = 64, which equals
    # hidden_size / num_heads = 1024 / 16.
    rope_theta: float = 256.0
    rope_axes_dims: list[int] = [16, 24, 24]
    rope_axes_lens: list[int] = [256, 128, 128]
    rope_zero_centered: list[bool] = [False, True, True]

    # Required (no default). NOTE(review): presumably the feature width of the
    # context-encoder output fed to the denoiser — confirm.
    context_dim: int


class JiT_B_16_Config(DenoiserConfig):
    """Denoiser preset "B/16": 12 layers, width 768, 12 heads, 16-pixel patches.

    Unlike ``DenoiserConfig``, this preset also sets ``context_dim`` (768), so
    it can be instantiated without any arguments.
    """

    patch_size: int = 16

    depth: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    bottleneck_dim: int = 128

    context_dim: int = 768

    rope_axes_dims: list[int] = [16, 24, 24]  # sum = 64 = 768 / 12
    rope_axes_lens: list[int] = [
        256,  # max 256 token text
        128,  # 2048x2048 image size
        128,
    ]


# The allowed values of the ``type`` tag carried by the context-encoder configs.
ContextType = Literal["class", "text"]


class ClassContextConfig(BaseModel):
    """Context-encoder settings for class-label conditioning.

    Attributes:
        type: Literal tag identifying this context kind; always ``"class"``.
        label2id_map_path: Path to a JSON file expected to map label strings
            to integer ids (the contents are not validated here).
    """

    type: Literal["class"] = "class"
    label2id_map_path: str

    @property
    def label2id(self) -> dict[str, int]:
        """Load and return the label-to-id mapping from ``label2id_map_path``.

        The file is re-read on every access; nothing is cached.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON text is UTF-8 (RFC 8259). Pin the encoding so non-ASCII labels
        # do not depend on the platform's locale default (e.g. cp1252 on Windows).
        with open(self.label2id_map_path, "r", encoding="utf-8") as f:
            return json.load(f)


class TextContextConfig(BaseModel):
    """Context-encoder settings for text conditioning."""

    type: Literal["text"] = "text"
    # Identifier of the pretrained text encoder. NOTE(review): presumably a
    # Hugging Face Hub repo id or local path — confirm against the loader.
    pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"


# Either context-encoder config. NOTE(review): this is a plain union, not a
# discriminated one, so pydantic chooses whichever member validates rather
# than dispatching on the ``type`` field.
ContextConfig = ClassContextConfig | TextContextConfig


class JiTConfig(BaseModel):
    """Top-level configuration: compute dtype, context encoder and denoiser.

    Attributes:
        dtype: Floating-point dtype name; any alias accepted by
            ``str_to_dtype`` (e.g. ``"fp16"``, ``"bf16"``). Checked when the
            config is constructed; the string itself is stored unchanged.
        context_encoder: Class-label or text context configuration.
        denoiser: Denoiser architecture; defaults to the B/16 preset.
    """

    dtype: str = "float32"

    context_encoder: ContextConfig
    # NOTE(review): dict input (e.g. from a JSON file) is validated against the
    # annotated DenoiserConfig, not JiT_B_16_Config, so B/16 defaults do not
    # apply and ``context_dim`` must then be supplied explicitly.
    denoiser: DenoiserConfig = JiT_B_16_Config()

    @field_validator("dtype")
    @classmethod
    def check_dtype(cls, value: str) -> str:
        """Reject unsupported dtype names when the config is built.

        Without this check, a typo in ``dtype`` only surfaces later, the first
        time ``torch_dtype`` is read.
        """
        str_to_dtype(value)  # raises ValueError (reported as a ValidationError)
        return value

    @property
    def torch_dtype(self) -> torch.dtype:
        """The ``torch.dtype`` corresponding to ``dtype``."""
        return str_to_dtype(self.dtype)