|
|
import torch.nn as nn |
|
|
|
|
|
from .text_tower import Text_Tower |
|
|
|
|
|
class Knowledge_Encoder(nn.Module):
    """Encode text with a text tower, project it, and add a modality embedding.

    Submodules:
        text_tower: ``Text_Tower`` built from ``biolord_checkpoint``.
        projection_layer: ``Linear(768, 768) -> GELU -> Linear(768, 768)``.
        modality_embed: embedding table with 5 rows of width 768.
    """

    def __init__(self, biolord_checkpoint='FremyCompany/BioLORD-2023-C'):
        """Build the text tower, the projection MLP and the modality table.

        Args:
            biolord_checkpoint: identifier forwarded unchanged to
                ``Text_Tower``.
        """
        super().__init__()

        self.text_tower = Text_Tower(biolord_checkpoint)

        # NOTE(review): the projection input width (768) presumably matches
        # the feature width produced by Text_Tower -- confirm against it.
        self.projection_layer = nn.Sequential(
            nn.Linear(768, 768),
            nn.GELU(),
            nn.Linear(768, 768),
        )

        # One learned 768-wide vector for each of the 5 modality indices.
        self.modality_embed = nn.Embedding(5, 768)

    def forward(self, text, modality, device):
        """Return the projected text feature plus the modality embedding.

        Args:
            text: input passed straight to ``self.text_tower``.
            modality: integer index tensor looked up in ``self.modality_embed``
                (valid indices are 0..4). The looked-up embeddings must
                broadcast against the projected text feature.
            device: passed straight to ``self.text_tower``.

        Returns:
            ``projection_layer(text_tower(text, device)) + modality_embed(modality)``.
        """
        text_feature = self.text_tower(text, device)
        proj_text_feature = self.projection_layer(text_feature)

        modality_feature = self.modality_embed(modality)

        # Only the projected feature is returned, so the modality embedding is
        # added to it alone; adding it to the raw text feature as well would
        # just allocate a tensor that nothing reads.
        return proj_text_feature + modality_feature