drkostas commited on
Commit
69c0e29
·
verified ·
1 Parent(s): 484b5bb

Upload src/models/__init__.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/models/__init__.py +53 -0
src/models/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ["build_teacher", "build_student"]
2
+
3
+ from typing import Dict, Any
4
+
5
+ from .vision_transformer import VisionTransformerMIM
6
+
7
+
8
def build_student(cfg: Dict[str, Any]) -> VisionTransformerMIM:
    """Construct the student ViT-MIM model from the experiment config.

    Position-embedding defaults depend on whether mask tokens are enabled:
    sparse mode (``use_mask_tokens`` False) defaults to absolute position
    embeddings, since the token sequence length varies; dense mode defaults
    to a shared relative position bias instead. Any of these defaults can
    be overridden explicitly in ``cfg["model"]["student"]``.
    """
    student = cfg["model"]["student"]
    mask_tokens = student.get("use_mask_tokens", True)

    # Pick position-embedding defaults as (abs, sincos, per-layer rel, shared rel).
    if mask_tokens:
        # Dense mode: shared relative position bias by default.
        abs_default, sincos_default, rel_default, shared_default = (
            False, False, False, True,
        )
    else:
        # Sparse mode: variable sequence length requires absolute embeddings.
        abs_default, sincos_default, rel_default, shared_default = (
            True, False, False, False,
        )

    return VisionTransformerMIM(
        img_size=student["img_size"],
        patch_size=student["patch_size"],
        embed_dim=student["embed_dim"],
        depth=student["depth"],
        num_heads=student["num_heads"],
        mlp_ratio=student.get("mlp_ratio", 4.0),
        drop_path_rate=student.get("drop_path_rate", 0.1),
        init_values=student.get("init_values", 0.1),
        use_abs_pos_emb=student.get("use_abs_pos_emb", abs_default),
        use_sincos_pos_emb=student.get("use_sincos_pos_emb", sincos_default),
        use_shared_rel_pos_bias=student.get("use_shared_rel_pos_bias", shared_default),
        use_rel_pos_bias=student.get("use_rel_pos_bias", rel_default),
        use_mask_tokens=mask_tokens,
    )
43
+
44
+
45
def build_teacher(cfg: Dict[str, Any]):
    """Construct the frozen CLIP teacher model.

    ``ClipTeacher`` is imported lazily inside this function so that merely
    importing this package does not pay the cost of loading CLIP machinery.
    """
    from .clip_teacher import ClipTeacher

    teacher = cfg["model"]["teacher"]
    return ClipTeacher(
        model_name=teacher["name"],
        layer_extraction=teacher.get("layer_extraction", "last"),
        num_layers_to_extract=teacher.get("num_layers_to_extract", 1),
    )