Spaces: Running on Zero
import json
from typing import Literal

import torch
from pydantic import BaseModel
# Accepted spelling variants for each supported torch dtype.
FP32_STR = ["float32", "fp32"]
FP16_STR = ["float16", "fp16", "half"]
BF16_STR = ["bfloat16", "bf16"]


def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name (case-insensitive) to the corresponding ``torch.dtype``.

    Args:
        dtype_str: one of the aliases in FP32_STR / FP16_STR / BF16_STR,
            in any letter case.

    Raises:
        ValueError: if the string matches none of the known aliases.
    """
    normalized = dtype_str.lower()
    # Walk the alias tables in order and return the first match.
    for aliases, dtype in (
        (FP32_STR, torch.float32),
        (FP16_STR, torch.float16),
        (BF16_STR, torch.bfloat16),
    ):
        if normalized in aliases:
            return dtype
    raise ValueError(f"Unsupported dtype string: {normalized}")
class DenoiserConfig(BaseModel):
    """Architecture hyperparameters for the JiT denoiser transformer.

    NOTE(review): pydantic copies mutable (list) defaults per instance, so
    the bare list defaults below are safe here, unlike on a plain class.
    """

    # Side length of each square image patch (ViT-style patch embedding).
    patch_size: int = 16
    in_channels: int = 3
    out_channels: int = 3
    # Transformer embedding width.
    hidden_size: int = 1024
    # Number of transformer blocks.
    depth: int = 24
    num_heads: int = 16
    # Presumably MLP hidden dim = hidden_size * mlp_ratio — confirm in model code.
    mlp_ratio: float = 4.0
    attn_dropout: float = 0.0
    proj_dropout: float = 0.0
    bottleneck_dim: int = 128
    num_time_tokens: int = 4
    rope_theta: float = 256.0
    # Per-axis RoPE settings; three axes, which look like (text, height, width)
    # given the subclass comments below — TODO confirm. The dims sum to
    # hidden_size / num_heads (16 + 24 + 24 = 64 = 1024 / 16).
    rope_axes_dims: list[int] = [16, 24, 24]
    rope_axes_lens: list[int] = [256, 128, 128]
    rope_zero_centered: list[bool] = [False, True, True]
    # Width of the conditioning context embeddings; no default, must be set
    # by a subclass or at construction time.
    context_dim: int
class JiT_B_16_Config(DenoiserConfig):
    """JiT-Base/16 preset: ViT-Base sized denoiser (768 wide, 12 deep)."""

    patch_size: int = 16
    depth: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    bottleneck_dim: int = 128
    # Context embeddings match the hidden size here.
    context_dim: int = 768
    rope_axes_dims: list[int] = [16, 24, 24]  # sum = 64 = 768 / 12
    rope_axes_lens: list[int] = [
        256,  # max 256 token text
        128,  # 2048x2048 image size
        128,
    ]
# String tag selecting the conditioning mode (mirrors the `type` field of the
# context config classes below).
ContextType = Literal["class", "text"]
class ClassContextConfig(BaseModel):
    """Configuration for class-label conditioning.

    `label2id_map_path` points to a JSON file mapping label strings to
    integer class ids.
    """

    type: Literal["class"] = "class"
    label2id_map_path: str

    def label2id(self) -> dict[str, int]:
        """Load and return the label -> id mapping from the JSON file.

        Raises:
            FileNotFoundError: if the mapping file does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        # Pin the encoding: JSON is UTF-8 by spec, and the platform default
        # (e.g. cp1252 on Windows) would mis-decode non-ASCII labels.
        with open(self.label2id_map_path, "r", encoding="utf-8") as f:
            return json.load(f)
class TextContextConfig(BaseModel):
    """Configuration for text conditioning via a pretrained encoder."""

    type: Literal["text"] = "text"
    # Presumably a Hugging Face model id for the text encoder — confirm
    # against the code that loads it.
    pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"
# Union of the supported context-encoder configurations; the `type` literal
# field on each variant disambiguates them when parsing.
ContextConfig = ClassContextConfig | TextContextConfig
class JiTConfig(BaseModel):
    """Top-level model configuration: precision, context encoder, denoiser."""

    # Parameter dtype as a string; see str_to_dtype for the accepted aliases.
    dtype: str = "float32"
    # Class-label or text conditioning; required, no default.
    context_encoder: ContextConfig
    # Denoiser architecture; defaults to the JiT-B/16 preset.
    denoiser: DenoiserConfig = JiT_B_16_Config()

    def torch_dtype(self) -> torch.dtype:
        """Resolve the `dtype` string to the corresponding torch.dtype.

        Raises:
            ValueError: if `dtype` is not a recognized alias.
        """
        return str_to_dtype(self.dtype)