Upload teacher_code/builder.py with huggingface_hub
Browse files- teacher_code/builder.py +81 -0
teacher_code/builder.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from .clip_encoder import CLIPVisionTower
|
| 3 |
+
from .imagebind import ImageBindWrapper
|
| 4 |
+
from .open_clip_encoder import OpenCLIPVisionTower
|
| 5 |
+
from .siglip_encoder import SigLipVisionTower
|
| 6 |
+
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
|
| 7 |
+
|
| 8 |
+
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
|
| 9 |
+
from .dev_eva_clip.eva_vit import EvaViTWrapper
|
| 10 |
+
|
| 11 |
+
from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn
|
| 12 |
+
from blip3o.model.sana_crossattn import SanaCrossAttnConfig, SanaCrossAttn
|
| 13 |
+
|
| 14 |
+
from diffusers.models import AutoencoderKL
|
| 15 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
|
| 16 |
+
from diffusers import SanaTransformer2DModel
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the understanding-side vision tower named in the config.

    Selection order:
      1. names containing "siglip"  -> SigLipVisionTower
      2. names containing "eva"     -> EvaClipVisionTower
      3. existing local paths, or names starting with "openai"/"laion",
         or containing "ShareGPT4V" -> CLIPVisionTowerS2 when ``cfg.s2``
         is truthy, else CLIPVisionTower

    Args:
        vision_tower_cfg: config object; ``mm_vision_tower`` (or the legacy
            ``vision_tower`` attribute) holds the model name or path, and
            the optional ``s2`` flag selects the multi-scale S2 CLIP variant.
        **kwargs: forwarded verbatim to the tower constructor.

    Returns:
        The constructed vision-tower module.

    Raises:
        ValueError: if the name is missing or matches no known tower.
    """
    name = getattr(vision_tower_cfg, 'mm_vision_tower',
                   getattr(vision_tower_cfg, 'vision_tower', None))
    # Fail early with a clear message instead of letting
    # os.path.exists(None) raise a confusing TypeError below.
    if name is None:
        raise ValueError('Unknown vision tower: None')
    if "siglip" in name:
        return SigLipVisionTower(name, vision_tower_cfg=vision_tower_cfg, **kwargs)
    if "eva" in name:
        return EvaClipVisionTower(name, args=vision_tower_cfg, **kwargs)
    if (os.path.exists(name)
            or name.startswith("openai")
            or name.startswith("laion")
            or "ShareGPT4V" in name):
        if getattr(vision_tower_cfg, 's2', False):
            return CLIPVisionTowerS2(name, args=vision_tower_cfg, **kwargs)
        return CLIPVisionTower(name, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {name}')
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def build_gen_vision_tower(vision_tower_cfg, **kwargs):
    """Build the generation-side vision tower selected by ``cfg.gen_vision_tower``.

    Dispatches on the tower name: "siglip" -> SigLipVisionTower,
    "eva" -> EvaClipVisionTower; existing local paths and openai/laion/
    ShareGPT4V checkpoints map to the CLIP towers (the S2 variant when
    ``cfg.s2`` is truthy). Any other name raises ValueError. Note that a
    missing ``gen_vision_tower`` attribute raises AttributeError.
    """
    tower_name = getattr(vision_tower_cfg, 'gen_vision_tower')
    path_exists = os.path.exists(tower_name)
    s2_enabled = getattr(vision_tower_cfg, 's2', False)

    if "siglip" in tower_name:
        return SigLipVisionTower(tower_name, vision_tower_cfg=vision_tower_cfg, **kwargs)

    if "eva" in tower_name:
        return EvaClipVisionTower(tower_name, args=vision_tower_cfg, **kwargs)

    is_clip_checkpoint = (
        path_exists
        or tower_name.startswith("openai")
        or tower_name.startswith("laion")
        or "ShareGPT4V" in tower_name
    )
    if is_clip_checkpoint:
        tower_cls = CLIPVisionTowerS2 if s2_enabled else CLIPVisionTower
        return tower_cls(tower_name, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {tower_name}')
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_dit(vision_tower_cfg, **kwargs):
    """Build the SANA student/teacher diffusion transformers and a scheduler.

    Side effect: unconditionally overwrites ``vision_tower_cfg.hidden_size``
    with 896, discarding any prior value — callers that read the config
    afterwards see the overwritten value.

    Args:
        vision_tower_cfg: training config object; mutated in place
            (see side effect above).
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        tuple: (student ``SanaTransformer2DModel`` — Sana 600M / 1024px,
                teacher ``SanaTransformer2DModel`` — SANA1.5 4.8B / 1024px,
                ``DPMSolverMultistepScheduler`` loaded from the 600M/512px repo).
    """
    # NOTE(review): hard-coded override; an earlier model-size-dependent
    # selection (2048/3584/3072 keyed on checkpoint name) was retired in
    # favor of a fixed 896 — confirm this matches the cross-attn width of
    # the SANA checkpoints below.
    vision_tower_cfg.hidden_size = 896
    print("="*20, "vision tower config", vision_tower_cfg, "="*20)
    print("="*20, "Building SANA with hidden size", vision_tower_cfg.hidden_size, "="*20)

    # Both transformers are materialized on CPU; device placement is left
    # to the caller/trainer.
    dit = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/Sana_600M_1024px_diffusers",
        device_map="cpu",
        subfolder="transformer",
    )
    dit_teacher = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers",
        device_map="cpu",
        subfolder="transformer",
    )
    # NOTE(review): the scheduler config comes from the 512px repo while the
    # transformers are 1024px — presumably intentional; verify.
    noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        "Efficient-Large-Model/Sana_600M_512px_diffusers",
        subfolder="scheduler",
    )

    return dit, dit_teacher, noise_scheduler
|
| 80 |
+
|
| 81 |
+
|