Spaces:
Runtime error
Runtime error
| import os | |
| import yaml | |
| import torch | |
| from transformers.utils.hub import get_file_from_repo | |
| from .clip.clip_encoder import CLIPVisionTower | |
| from .eva_clip.eva_clip_encoder import EvaClipVisionTower | |
| from .internvit.internvit_encoder import InternViTVisionTower | |
| from .siglip.siglip_encoder import SiglipVisionTower, SiglipVisionTowerS2 | |
| from .whale.init_model import init_model | |
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Instantiate the vision tower selected by ``vision_tower_cfg``.

    The tower name is read from ``vision_tower_cfg.mm_vision_tower`` and, if
    that attribute is absent, from ``vision_tower_cfg.vision_tower``. The
    name is matched case-insensitively by substring, in this order:
    ``"sig"``, ``"eva"``, ``"clip"``, ``"internvit"``.

    Args:
        vision_tower_cfg: Configuration object. Besides the tower name, the
            optional ``use_s2`` attribute (default ``False``) selects the S2
            variant, which only the SigLIP branch provides.
        **kwargs: Forwarded unchanged to the tower constructor.

    Returns:
        The constructed vision tower instance.

    Raises:
        ValueError: If no tower name is configured, if the name matches none
            of the known towers, or if ``use_s2`` is set for a tower that
            has no S2 variant.
    """
    vision_tower = getattr(
        vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
    )
    # Fail early with a readable message instead of an AttributeError from
    # calling ``.lower()`` on None further down.
    if vision_tower is None:
        raise ValueError(
            "No vision tower configured: set `mm_vision_tower` or `vision_tower`"
        )

    use_s2 = getattr(vision_tower_cfg, "use_s2", False)
    tower_name = vision_tower.lower()

    # Branch order matters: names such as "siglip" or "eva_clip" also contain
    # "clip", so the more specific substrings must be tested first.
    if "sig" in tower_name:
        if use_s2:
            return SiglipVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
        return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    if "eva" in tower_name:
        if use_s2:
            raise ValueError("Currently not supporting S2 for EVA-CLIP")
        return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    if "clip" in tower_name:
        if use_s2:
            raise ValueError("Currently not supporting S2 for CLIP")
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    if "internvit" in tower_name:
        if use_s2:
            raise ValueError("Currently not supporting S2 for InternViT")
        return InternViTVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
def _require_repo_file(repo, filename):
    """Resolve ``filename`` in ``repo``, raising if the file is not there."""
    path = get_file_from_repo(repo, filename)
    if path is None:
        raise FileNotFoundError(
            f"Could not find {filename!r} in audio encoder repo {repo!r}"
        )
    return path


def build_audio_encoder(audio_encoder_config, **kwargs):
    """Build the audio encoder and load its pretrained weights.

    Reads ``train.yaml`` from the location named by
    ``audio_encoder_config.mm_audio_encoder``, overrides a few ``model_conf``
    entries from the config object, builds the model with ``init_model`` and
    then copies in every tensor from ``final.pt`` whose key and shape match
    the freshly built model. Mismatched or missing keys are reported on
    stdout and keep their initial values.

    Args:
        audio_encoder_config: Config object. ``mm_audio_encoder`` is
            required; ``freeze_audio_encoder`` (default ``True``),
            ``freeze_audio_encoder_adapter`` (default ``True``),
            ``audio_prompt_finetune`` (default ``False``) and
            ``audio_prompt_num`` (default ``0``) are optional.
        **kwargs: Accepted but not used by this function.

    Returns:
        The model returned by ``init_model`` with the checkpoint weights
        loaded.

    Raises:
        FileNotFoundError: If ``train.yaml`` or ``final.pt`` cannot be
            resolved.
    """
    repo = audio_encoder_config.mm_audio_encoder

    # NOTE(security): yaml.FullLoader and torch.load (pickle) below must only
    # be pointed at trusted repositories; both parse attacker-controllable
    # data if the repo is not trusted.
    with open(_require_repo_file(repo, "train.yaml"), "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Left optional on purpose: a missing file yields None, which is passed on
    # as-is (presumably init_model copes with that -- TODO confirm).
    configs["cmvn_file"] = get_file_from_repo(repo, "global_cmvn")

    model_conf = configs["model_conf"]
    model_conf["freeze_encoder"] = getattr(
        audio_encoder_config, "freeze_audio_encoder", True
    )
    # The key spelling "freeze_adpter" is kept verbatim: it is a runtime key
    # that downstream code presumably looks up under this exact name.
    model_conf["freeze_adpter"] = getattr(
        audio_encoder_config, "freeze_audio_encoder_adapter", True
    )
    model_conf["audio_prompt_finetune"] = getattr(
        audio_encoder_config, "audio_prompt_finetune", False
    )
    model_conf["audio_prompt_num"] = getattr(
        audio_encoder_config, "audio_prompt_num", 0
    )

    audio_encoder = init_model(configs)

    checkpoint = torch.load(_require_repo_file(repo, "final.pt"), map_location="cpu")

    # Merge the checkpoint into the model's own state dict so that missing or
    # shape-mismatched entries fall back to the initial weights instead of
    # making load_state_dict fail.
    model_dict = audio_encoder.state_dict()
    for key in model_dict:
        if key not in checkpoint:
            print("Key {} has not in resume model".format(key))
            continue
        if model_dict[key].shape == checkpoint[key].shape:
            model_dict[key] = checkpoint[key]
        else:
            print(
                "Key {} has different shape, {} VS {}".format(
                    key, model_dict[key].shape, checkpoint[key].shape
                )
            )

    audio_encoder.load_state_dict(model_dict)
    return audio_encoder