# NOTE: Hugging Face Space page header ("Spaces: Running") removed during extraction.
from typing import Dict, Optional

import torch
import torch.distributed as dist
from peft import LoraConfig, PeftModel, get_peft_model
from torch import nn, Tensor
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    Qwen2VLForConditionalGeneration,
)

from src.arguments import ModelArguments
from src.vlm_backbone.llava_next import LlavaNextForConditionalGeneration
from src.vlm_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
class MMEBModel(nn.Module):
    """Bi-encoder for multimodal embedding (MMEB) training and inference.

    Wraps a vision-language backbone, pools its final hidden states into
    fixed-size embeddings, and computes an InfoNCE-style contrastive loss
    between query and target embeddings (optionally with hard negatives).
    """

    # Fallback backbone class when no explicit VLM backbone matches in build()/load().
    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'cls',
                 normalize: bool = False,
                 temperature: float = 1.0,
                 ):
        """
        Args:
            encoder: backbone producing hidden states for pooling.
            pooling: pooling strategy; only 'last'/'eos' are implemented.
            normalize: if True, L2-normalize pooled embeddings.
            temperature: softmax temperature applied to similarity scores.
        """
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # Cache DDP topology so forward() can all-gather cross-device negatives.
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def encode_input(self, input) -> Tensor:
        """Run the backbone on a tokenized batch and pool the last hidden layer."""
        hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states, input['attention_mask'])
        return pooled_output

    def _pooling(self, last_hidden_state: Tensor, attention_mask: Tensor) -> Tensor:
        """Pool (batch, seq, dim) hidden states down to (batch, dim) embeddings."""
        if self.pooling == 'last' or self.pooling == 'eos':
            # Index of the last non-padding token per sequence
            # (assumes right padding, as configured by build()/load()).
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @classmethod
    def build(cls, model_args: ModelArguments, **hf_kwargs):
        """Construct a trainable model, optionally wrapped with (Do)LoRA adapters.

        Fix: decorated as @classmethod — the method takes `cls`, reads
        `cls.TRANSFORMER_CLS`, and is invoked on the class, but the decorator
        was missing, so `MMEBModel.build(model_args)` would have bound
        `model_args` to `cls`.
        """
        # Backbone-specific loading; each branch sets the padding/cache flags
        # that the pooling code relies on.
        lora_target_modules = None
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "qwen":
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
            base_model.padding_side = "right"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # NOTE(review): HF configs usually read the private `_attn_implementation`
            # attribute; confirm this public name is honored by Phi3VForCausalLM.
            config.attn_implementation = "flash_attention_2"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_args.model_backbone == "internvl_2_5":
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            # Flash attention is disabled for this backbone.
            # (Original note, translated: the image size was expected to come
            # from data_args.force_image_size — currently not applied.)
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                torch_dtype=torch.bfloat16
            )
            # InternVL exposes its own LoRA target list; prefer it over CLI args.
            lora_target_modules = base_model.get_lora_target_modules()
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"

        if model_args.lora:
            if lora_target_modules is None:
                lora_target_modules = model_args.lora_target_modules.split(',')
            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            lora_model = get_peft_model(base_model, lora_config)
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        return model

    @classmethod
    def load(cls, model_args: ModelArguments, **hf_kwargs):
        """Load a trained model for inference, merging LoRA weights if present.

        Fix: decorated as @classmethod for the same reason as build().
        """
        # Prefer an explicit checkpoint; fall back to the base model name.
        checkpoint_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        if model_args.model_backbone == "llava_next":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
            base_model.padding_side = "left"
        elif model_args.model_backbone == "phi35v":
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **hf_kwargs, config=config,
                                                          attn_implementation="flash_attention_2",
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == "internvl_2_5":
            print("loading model")
            from src.vlm_backbone.intern_vl import InternVLChatConfig, InternVLChatModel
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_args.model_name,
                trust_remote_code=True
            )
            config = InternVLChatConfig.from_pretrained(model_args.model_name)
            config.use_flash_attn = False
            base_model = InternVLChatModel.from_pretrained(
                model_args.model_name,
                config=config,
                tokenizer=tokenizer,
                torch_dtype=torch.bfloat16
            )
        else:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                checkpoint_path, **hf_kwargs, config=config,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            base_model.padding_side = "right"

        if model_args.lora:
            print("loading lora parameters")
            # Merge adapter weights into the backbone so inference needs no PEFT hooks.
            lora_config = LoraConfig.from_pretrained(checkpoint_path)
            lora_model = PeftModel.from_pretrained(base_model, checkpoint_path, config=lora_config)
            merged_model = lora_model.merge_and_unload()
            model = cls(
                encoder=merged_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize
            )
        return model

    def save(self, output_dir: str):
        """Persist the encoder (base or PEFT-wrapped) via save_pretrained()."""
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, neg: Dict[str, Tensor] = None):
        """Encode query/target(/negative) batches and return the contrastive loss.

        If either `qry` or `tgt` is missing, returns the available embeddings
        instead (encoding/inference mode).
        """
        qry_reps = self.encode_input(qry) if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt) if tgt else None  # (bsz_per_device, dim)
        neg_reps = self.encode_input(neg) if neg else None  # (bsz_per_device, dim)

        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        # Gather representations across ranks so every device scores against
        # the global batch of in-batch negatives.
        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
            all_neg_reps = self._dist_gather_tensor(neg_reps) if neg_reps is not None else None
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps
            all_neg_reps = neg_reps

        # Compute similarity scores: (num_qry, num_tgt).
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)

        # Hard negatives are appended as extra columns; the targets below
        # still index only the positive columns.
        if all_neg_reps is not None:
            qry_neg_cos = self.compute_similarity(all_qry_reps, all_neg_reps)
            scores = torch.cat([scores, qry_neg_cos], dim=1)

        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # Fix: the label stride is targets-per-query (tgt // qry). The
        # original used qry // tgt, which floors to 0 whenever each query has
        # multiple targets, pointing every label at column 0. For 1:1 batches
        # both expressions equal 1, so the common case is unchanged.
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            # DDP averages gradients across ranks; rescaling the loss by
            # world_size restores the gradient magnitude of the gathered
            # global-batch objective.
            loss = loss * self.world_size
        return loss

    def _dist_gather_tensor(self, t: Tensor) -> Tensor:
        """All-gather `t` across ranks, keeping this rank's slice differentiable."""
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns detached tensors; substituting the local tensor
        # restores autograd flow for this rank's portion.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors

    def compute_similarity(self, q_reps: Tensor, p_reps: Tensor) -> Tensor:
        """Dot-product similarity matrix between queries and passages/targets."""
        return torch.matmul(q_reps, p_reps.transpose(0, 1))