Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import os | |
| from peft import PeftModel | |
| from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer | |
| from .projection_layer import ProjectionBlock | |
class MultimodalPhiModel(PreTrainedModel):
    """Causal language model wrapper that splices projected image features into the prompt.

    The wrapper holds three collaborators:

    * ``phi_model`` - the causal LM that consumes ``inputs_embeds``.
    * ``image_projection`` - a module mapping image features to the LM's
      embedding width.
    * ``tokenizer`` - stored for callers; it is not used by any method here.

    ``forward`` builds the input sequence as
    ``[start token embeddings, projected image features, end token embeddings]``.
    """

    # Hub id of the base checkpoint that adapter weights are applied to.
    _BASE_MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
    # File (inside a checkpoint directory) holding the projector state dict.
    _PROJECTOR_FILENAME = "image_projector.pth"
    # Width of the image feature vectors accepted by the projector.
    _IMAGE_FEATURE_DIM = 512

    def __init__(self, phi_model, tokenizer, projection):
        """Wrap an already constructed language model, tokenizer and projector.

        Args:
            phi_model: Causal LM; its ``config`` is reused as this model's config.
            tokenizer: Tokenizer kept on the instance for convenience.
            projection: Module applied to image features in :meth:`encode`.
        """
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        # Filled in by ``from_pretrained``; stays ``None`` for direct construction.
        self.base_phi_model = None

    def gradient_checkpointing_enable(self, **kwargs):
        """Forward gradient-checkpointing activation to the wrapped language model."""
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        """Forward gradient-checkpointing deactivation to the wrapped language model."""
        self.phi_model.gradient_checkpointing_disable()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        """Build the model from a PEFT adapter directory plus optional projector weights.

        The base checkpoint is always ``microsoft/Phi-3.5-mini-instruct`` loaded in
        bfloat16; the adapter found at ``pretrained_model_name_or_path`` is applied
        and merged into it. If ``image_projector.pth`` exists in the same directory
        its weights are loaded into a fresh ``ProjectionBlock``; otherwise the
        projector keeps its initial weights.

        Args:
            pretrained_model_name_or_path: Location handed to
                ``PeftModel.from_pretrained`` and searched for the projector file.
            *model_args: Accepted for signature compatibility; ignored.
            debug: Accepted for signature compatibility; ignored.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            A ``MultimodalPhiModel`` whose ``base_phi_model`` attribute references
            the base model object that was loaded.
        """
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            cls._BASE_MODEL_NAME,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(cls._BASE_MODEL_NAME, trust_remote_code=True)

        adapted_model = PeftModel.from_pretrained(base_phi_model, pretrained_model_name_or_path)
        phi_model = adapted_model.merge_and_unload()

        # Derive the projector width from the model config in both branches so the
        # projector always matches the language model's embedding size.
        input_dim = cls._IMAGE_FEATURE_DIM
        output_dim = phi_model.config.hidden_size
        projector = ProjectionBlock(input_dim, output_dim)

        projector_path = os.path.join(pretrained_model_name_or_path, cls._PROJECTOR_FILENAME)
        if os.path.exists(projector_path):
            # SECURITY: torch.load may unpickle arbitrary objects; only load
            # projector files that come from a trusted source.
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)
            load_result = projector.load_state_dict(projector_state_dict, strict=False)
            # strict=False tolerates key mismatches; surface them instead of
            # silently leaving part of the projector at its initial weights.
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    "Projector state dict mismatch: "
                    f"missing_keys={list(load_result.missing_keys)}, "
                    f"unexpected_keys={list(load_result.unexpected_keys)}"
                )
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")

        model = cls(phi_model, tokenizer, projector)
        # NOTE(review): merge_and_unload may hand back the same module object as
        # base_phi_model, in which case one module is registered under two
        # attribute names - confirm against the PEFT documentation.
        model.base_phi_model = base_phi_model
        return model

    def save_pretrained(self, save_directory):
        """Write the language model, projector weights and config to ``save_directory``.

        The projector is stored as ``image_projector.pth`` next to whatever
        ``phi_model.save_pretrained`` writes.

        NOTE(review): ``from_pretrained`` reads the directory through
        ``PeftModel.from_pretrained``; whether a directory written here can be
        read back depends on whether ``phi_model`` saves adapter files - verify.
        """
        self.phi_model.save_pretrained(save_directory)
        projector_path = os.path.join(save_directory, self._PROJECTOR_FILENAME)
        torch.save(self.image_projection.state_dict(), projector_path)
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        """Project image features with ``image_projection`` and return the result."""
        return self.image_projection(image_features)

    def forward(self, start_input_ids, end_input_ids, image_features=None, attention_mask=None, labels=None):
        """Run the language model on text embeddings with optional image embeddings in between.

        Args:
            start_input_ids: Token ids placed before the image segment.
            end_input_ids: Token ids placed after the image segment.
            image_features: Features passed through :meth:`encode`; when ``None``
                the sequence is just start + end embeddings.
            attention_mask: Mask forwarded to the language model, or ``None``.
                Presumably it must cover the full concatenated sequence - verify
                against the caller.
            labels: Targets forwarded to the language model, or ``None``.

        Returns:
            The language model's output object (``return_dict=True``).
        """
        device = next(self.parameters()).device
        embed_tokens = self.phi_model.get_input_embeddings()
        start_embeds = embed_tokens(start_input_ids.to(device))
        end_embeds = embed_tokens(end_input_ids.to(device))

        segments = [start_embeds]
        if image_features is not None:
            image_embeddings = self.encode(image_features.to(device))
            # Match the token-embedding dtype rather than assuming bfloat16, so
            # the concatenation stays in the language model's own precision.
            segments.append(image_embeddings.to(start_embeds.dtype))
        segments.append(end_embeds)
        # Concatenate along dim 1 (assumed to be the sequence axis - TODO confirm).
        input_embeds = torch.cat(segments, dim=1)

        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)

        return self.phi_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )