from Multilingual_CLIP.multilingual_clip import Config_MCLIP
import transformers
import torch


class MultilingualCLIP(transformers.PreTrainedModel):
    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Underlying multilingual text transformer (architecture named by config.modelBase)
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        # Linear projection from the transformer's hidden size into the CLIP embedding space
        self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
                                                    out_features=config.numDims)

    def forward(self, txt, tokenizer, device):
        # Tokenize to a fixed length of 77 to match CLIP's text context window
        txt_tok = tokenizer(txt, padding='max_length', max_length=77, truncation=True, return_tensors='pt').to(device)
        embs = self.transformer(**txt_tok)[0]  # last hidden state: [batch, 77, transformerDimensions]
        att = txt_tok['attention_mask']
        # Zero out padding positions and scale every token by the inverse of the real
        # token count; this keeps per-token embeddings of shape [batch, 77, dim]
        embs = (embs * att.unsqueeze(2)) / att.sum(dim=1)[:, None].unsqueeze(2)
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        # Load the combined (transformer + projection head) weights directly rather
        # than letting transformers resolve the state dict key by key
        model.load_state_dict(state_dict)
        return model, [], [], []
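

# --- Usage sketch (illustrative addition, not part of the upstream module) ---
# A minimal end-to-end example, assuming the 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
# checkpoint that the Multilingual-CLIP authors publish on the Hugging Face Hub;
# any M-CLIP checkpoint whose config matches this class should work the same way.
if __name__ == '__main__':
    model_name = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'  # assumed checkpoint name
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = MultilingualCLIP.from_pretrained(model_name).to(device).eval()

    texts = ['a photo of a dog', 'ein Foto von einem Hund']
    with torch.no_grad():
        embeddings = model(texts, tokenizer, device)
    print(embeddings.shape)  # per-token projections: [batch, 77, numDims]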