Fahad-S committed on
Commit
e07a694
·
verified ·
1 Parent(s): 874d7d9

Upload teacher_code/builder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. teacher_code/builder.py +81 -0
teacher_code/builder.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .imagebind import ImageBindWrapper
4
+ from .open_clip_encoder import OpenCLIPVisionTower
5
+ from .siglip_encoder import SigLipVisionTower
6
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
7
+
8
+ from .eva_clip.eva_clip_encoder import EvaClipVisionTower
9
+ from .dev_eva_clip.eva_vit import EvaViTWrapper
10
+
11
+ from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn
12
+ from blip3o.model.sana_crossattn import SanaCrossAttnConfig, SanaCrossAttn
13
+
14
+ from diffusers.models import AutoencoderKL
15
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
16
+ from diffusers import SanaTransformer2DModel
17
+ import torch
18
+
19
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the understanding-side vision tower named in the config.

    Reads ``mm_vision_tower`` from the config, falling back to
    ``vision_tower``. Dispatch is by substring of the tower name/path:

    - ``"siglip"`` -> SigLipVisionTower
    - ``"eva"``    -> EvaClipVisionTower
    - an existing local path, a name starting with ``"openai"``/``"laion"``,
      or one containing ``"ShareGPT4V"`` -> CLIPVisionTower
      (CLIPVisionTowerS2 when the config sets ``s2``)

    Raises:
        ValueError: if no tower name is configured, or the name matches no
            known tower family.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    # Fail loudly with a clear message instead of the opaque TypeError that
    # os.path.exists(None) would raise when no tower name is configured.
    if vision_tower is None:
        raise ValueError('No vision tower configured: expected `mm_vision_tower` or `vision_tower` on the config')
    is_absolute_path_exists = os.path.exists(vision_tower)
    use_s2 = getattr(vision_tower_cfg, 's2', False)
    if "siglip" in vision_tower:
        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
    if "eva" in vision_tower:
        return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        if use_s2:
            return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
34
+
35
+
36
+
37
+
38
def build_gen_vision_tower(vision_tower_cfg, **kwargs):
    """Build the generation-side vision tower from ``cfg.gen_vision_tower``.

    Dispatch mirrors :func:`build_vision_tower`: SigLip for names containing
    ``"siglip"``, EVA-CLIP for ``"eva"``, and CLIP (multi-scale S2 variant
    when ``cfg.s2`` is set) for existing local paths or names starting with
    ``"openai"``/``"laion"`` or containing ``"ShareGPT4V"``.

    Raises:
        AttributeError: if the config has no ``gen_vision_tower``.
        ValueError: if the name matches no known tower family.
    """
    tower_name = getattr(vision_tower_cfg, 'gen_vision_tower')

    if "siglip" in tower_name:
        return SigLipVisionTower(tower_name, vision_tower_cfg=vision_tower_cfg, **kwargs)

    if "eva" in tower_name:
        return EvaClipVisionTower(tower_name, args=vision_tower_cfg, **kwargs)

    looks_like_clip = (
        os.path.exists(tower_name)
        or tower_name.startswith(("openai", "laion"))
        or "ShareGPT4V" in tower_name
    )
    if looks_like_clip:
        tower_cls = CLIPVisionTowerS2 if getattr(vision_tower_cfg, 's2', False) else CLIPVisionTower
        return tower_cls(tower_name, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {tower_name}')
53
+
54
+
55
+
56
def build_dit(vision_tower_cfg, **kwargs):
    """Load the student/teacher SANA transformers and a noise scheduler.

    Returns:
        tuple: ``(dit, dit_teacher, noise_scheduler)`` — the 600M SANA
        transformer (student), the 4.8B SANA-1.5 transformer (teacher),
        and a DPM-Solver multistep scheduler, all fetched from the Hub.

    NOTE(review): ``hidden_size`` is forced to 896 on the shared config,
    clobbering any value callers may have set — confirm this is intended.
    """
    vision_tower_cfg.hidden_size = 896
    banner = "=" * 20
    print(banner, "vision tower config", vision_tower_cfg, banner)
    print(banner, "Building SANA with hidden size", vision_tower_cfg.hidden_size, banner)

    # Both transformers are materialized on CPU; moving them to the target
    # device/dtype is left to the caller.
    student = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/Sana_600M_1024px_diffusers", device_map="cpu", subfolder="transformer"
    )
    teacher = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers", device_map="cpu", subfolder="transformer"
    )
    # Scheduler config comes from the 512px repo (matches the original code).
    scheduler = DPMSolverMultistepScheduler.from_pretrained(
        "Efficient-Large-Model/Sana_600M_512px_diffusers", subfolder="scheduler"
    )

    return student, teacher, scheduler
80
+
81
+