Upload src/models/__init__.py with huggingface_hub
Browse files- src/models/__init__.py +53 -0
src/models/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ["build_teacher", "build_student"]
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
|
| 5 |
+
from .vision_transformer import VisionTransformerMIM
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_student(cfg: Dict[str, Any]) -> VisionTransformerMIM:
|
| 9 |
+
"""Builds the student ViT model from config."""
|
| 10 |
+
student_cfg = cfg["model"]["student"]
|
| 11 |
+
|
| 12 |
+
use_mask_tokens = student_cfg.get("use_mask_tokens", True)
|
| 13 |
+
|
| 14 |
+
# Sparse mode requires absolute position embeddings (variable sequence length)
|
| 15 |
+
if not use_mask_tokens:
|
| 16 |
+
default_abs_pos = True
|
| 17 |
+
default_sincos_pos = False
|
| 18 |
+
default_rel_pos = False
|
| 19 |
+
default_shared_rel = False
|
| 20 |
+
else:
|
| 21 |
+
# Dense mode can use relative position embeddings
|
| 22 |
+
default_abs_pos = False
|
| 23 |
+
default_sincos_pos = False
|
| 24 |
+
default_rel_pos = False
|
| 25 |
+
default_shared_rel = True
|
| 26 |
+
|
| 27 |
+
model = VisionTransformerMIM(
|
| 28 |
+
img_size=student_cfg["img_size"],
|
| 29 |
+
patch_size=student_cfg["patch_size"],
|
| 30 |
+
embed_dim=student_cfg["embed_dim"],
|
| 31 |
+
depth=student_cfg["depth"],
|
| 32 |
+
num_heads=student_cfg["num_heads"],
|
| 33 |
+
mlp_ratio=student_cfg.get("mlp_ratio", 4.0),
|
| 34 |
+
drop_path_rate=student_cfg.get("drop_path_rate", 0.1),
|
| 35 |
+
init_values=student_cfg.get("init_values", 0.1),
|
| 36 |
+
use_abs_pos_emb=student_cfg.get("use_abs_pos_emb", default_abs_pos),
|
| 37 |
+
use_sincos_pos_emb=student_cfg.get("use_sincos_pos_emb", default_sincos_pos),
|
| 38 |
+
use_shared_rel_pos_bias=student_cfg.get("use_shared_rel_pos_bias", default_shared_rel),
|
| 39 |
+
use_rel_pos_bias=student_cfg.get("use_rel_pos_bias", default_rel_pos),
|
| 40 |
+
use_mask_tokens=use_mask_tokens,
|
| 41 |
+
)
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_teacher(cfg: Dict[str, Any]):
|
| 46 |
+
"""Builds the frozen CLIP teacher model. Import is lazy to avoid slow CLIP loading at import time."""
|
| 47 |
+
from .clip_teacher import ClipTeacher
|
| 48 |
+
teacher_cfg = cfg["model"]["teacher"]
|
| 49 |
+
return ClipTeacher(
|
| 50 |
+
model_name=teacher_cfg["name"],
|
| 51 |
+
layer_extraction=teacher_cfg.get("layer_extraction", "last"),
|
| 52 |
+
num_layers_to_extract=teacher_cfg.get("num_layers_to_extract", 1),
|
| 53 |
+
)
|