# Uploaded by drkostas via huggingface_hub (commit 69c0e29, verified).
__all__ = ["build_teacher", "build_student"]
from typing import Dict, Any
from .vision_transformer import VisionTransformerMIM
def build_student(cfg: Dict[str, Any]) -> VisionTransformerMIM:
    """Construct the student ViT-MIM model from the experiment config.

    Reads ``cfg["model"]["student"]`` for architecture hyperparameters.
    Position-embedding defaults depend on the masking mode:

    * sparse mode (``use_mask_tokens`` falsy): absolute position embeddings
      are required because the visible-token sequence length varies;
    * dense mode (default): a shared relative position bias is used instead.

    Any of the four position-embedding flags may still be overridden
    explicitly in the config.
    """
    student_cfg = cfg["model"]["student"]
    use_mask_tokens = student_cfg.get("use_mask_tokens", True)

    if use_mask_tokens:
        # Dense mode: rely on a shared relative position bias.
        defaults = {"abs": False, "sincos": False, "rel": False, "shared_rel": True}
    else:
        # Sparse mode: variable sequence length forces absolute embeddings.
        defaults = {"abs": True, "sincos": False, "rel": False, "shared_rel": False}

    return VisionTransformerMIM(
        img_size=student_cfg["img_size"],
        patch_size=student_cfg["patch_size"],
        embed_dim=student_cfg["embed_dim"],
        depth=student_cfg["depth"],
        num_heads=student_cfg["num_heads"],
        mlp_ratio=student_cfg.get("mlp_ratio", 4.0),
        drop_path_rate=student_cfg.get("drop_path_rate", 0.1),
        init_values=student_cfg.get("init_values", 0.1),
        use_abs_pos_emb=student_cfg.get("use_abs_pos_emb", defaults["abs"]),
        use_sincos_pos_emb=student_cfg.get("use_sincos_pos_emb", defaults["sincos"]),
        use_shared_rel_pos_bias=student_cfg.get("use_shared_rel_pos_bias", defaults["shared_rel"]),
        use_rel_pos_bias=student_cfg.get("use_rel_pos_bias", defaults["rel"]),
        use_mask_tokens=use_mask_tokens,
    )
def build_teacher(cfg: Dict[str, Any]):
    """Construct the frozen CLIP teacher from ``cfg["model"]["teacher"]``.

    ``ClipTeacher`` is imported inside the function body on purpose: loading
    CLIP is slow, and a module-level import would pay that cost even when
    only the student is needed.
    """
    from .clip_teacher import ClipTeacher

    teacher_cfg = cfg["model"]["teacher"]
    kwargs = {
        "model_name": teacher_cfg["name"],
        "layer_extraction": teacher_cfg.get("layer_extraction", "last"),
        "num_layers_to_extract": teacher_cfg.get("num_layers_to_extract", 1),
    }
    return ClipTeacher(**kwargs)