Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,208 Bytes
4b08319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
import json
from typing import Literal
from pydantic import BaseModel
# Accepted string aliases for each supported torch dtype.
FP32_STR = ["float32", "fp32"]
FP16_STR = ["float16", "fp16", "half"]
BF16_STR = ["bfloat16", "bf16"]


def str_to_dtype(dtype_str: str) -> torch.dtype:
    """Resolve a case-insensitive dtype alias to the matching ``torch.dtype``.

    Args:
        dtype_str: One of the aliases in ``FP32_STR``, ``FP16_STR`` or
            ``BF16_STR`` (any capitalization).

    Returns:
        The corresponding ``torch.float32``, ``torch.float16`` or
        ``torch.bfloat16``.

    Raises:
        ValueError: If the alias is not recognized.
    """
    key = dtype_str.lower()
    # Table-driven dispatch over the alias lists above.
    for aliases, dtype in (
        (FP32_STR, torch.float32),
        (FP16_STR, torch.float16),
        (BF16_STR, torch.bfloat16),
    ):
        if key in aliases:
            return dtype
    raise ValueError(f"Unsupported dtype string: {key}")
class DenoiserConfig(BaseModel):
    """Architecture hyperparameters for the denoiser transformer.

    All fields except ``context_dim`` have defaults; subclasses (see
    ``JiT_B_16_Config``) override them to define named model sizes.
    List defaults are safe here: pydantic copies defaults per instance.
    """

    # Side length of the square image patches fed to the transformer.
    patch_size: int = 16
    # Input / output image channel counts (RGB by default).
    in_channels: int = 3
    out_channels: int = 3
    # Transformer width, depth and attention head count.
    hidden_size: int = 1024
    depth: int = 24
    num_heads: int = 16
    # MLP hidden size as a multiple of hidden_size.
    mlp_ratio: float = 4.0
    # Dropout probabilities for attention weights and output projections.
    attn_dropout: float = 0.0
    proj_dropout: float = 0.0
    bottleneck_dim: int = 128
    # Number of tokens used to embed the diffusion timestep.
    num_time_tokens: int = 4
    # Rotary position-embedding parameters, one entry per axis
    # (presumably [text, image-height, image-width] — see JiT_B_16_Config's
    # per-entry comments; TODO confirm against the model code).
    rope_theta: float = 256.0
    rope_axes_dims: list[int] = [16, 24, 24]
    rope_axes_lens: list[int] = [256, 128, 128]
    rope_zero_centered: list[bool] = [False, True, True]
    # Dimensionality of the conditioning context; no default, must be set.
    context_dim: int
class JiT_B_16_Config(DenoiserConfig):
    """JiT-Base preset with 16x16 patches (ViT-B-like: 12 layers, 768 wide)."""

    patch_size: int = 16
    depth: int = 12
    hidden_size: int = 768
    num_heads: int = 12
    bottleneck_dim: int = 128
    context_dim: int = 768
    rope_axes_dims: list[int] = [16, 24, 24]  # sum = 64 = 768 / 12
    rope_axes_lens: list[int] = [
        256,  # max 256 token text
        128,  # 2048x2048 image size
        128,
    ]
# Tag values used to discriminate context-encoder configs.
ContextType = Literal["class", "text"]


class ClassContextConfig(BaseModel):
    """Context-encoder config for class-label conditioning.

    ``label2id_map_path`` points at a JSON file mapping label strings to
    integer class ids.
    """

    type: Literal["class"] = "class"
    # Path to the JSON label -> id mapping file.
    label2id_map_path: str

    @property
    def label2id(self) -> dict[str, int]:
        """Load and return the label -> id mapping from disk.

        Re-reads the file on every access.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # JSON is UTF-8 by spec (RFC 8259); pin the encoding so the read does
        # not depend on the platform locale default (non-ASCII labels would
        # otherwise fail on e.g. Windows cp1252).
        with open(self.label2id_map_path, "r", encoding="utf-8") as f:
            label2id = json.load(f)
        return label2id
class TextContextConfig(BaseModel):
    """Context-encoder config for free-text conditioning."""

    type: Literal["text"] = "text"
    # Hub identifier of the pretrained text encoder to load.
    pretrained_model: str = "p1atdev/Qwen3-VL-2B-Instruct-Text-Only"
# Union of supported context-encoder configs, discriminated by "type".
ContextConfig = ClassContextConfig | TextContextConfig


class JiTConfig(BaseModel):
    """Top-level JiT model configuration: dtype, context encoder, denoiser."""

    # Dtype alias string; resolved lazily via ``torch_dtype``.
    dtype: str = "float32"
    context_encoder: ContextConfig
    # Defaults to the JiT-Base/16 preset; pydantic copies the default
    # instance per model, so it is not shared across configs.
    denoiser: DenoiserConfig = JiT_B_16_Config()

    @property
    def torch_dtype(self) -> torch.dtype:
        """``dtype`` resolved to a ``torch.dtype`` (may raise ValueError)."""
        return str_to_dtype(self.dtype)
|