File size: 916 Bytes
7f3dfd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch.nn as nn
from .text_tower import Text_Tower
class Knowledge_Encoder(nn.Module):
    """Encode text with a BioLORD text tower and tag it with a modality embedding.

    The text tower's features are passed through a two-layer MLP
    (Linear -> GELU -> Linear), and a learned per-modality embedding vector is
    added to the projected result.
    """

    def __init__(
        self,
        biolord_checkpoint='FremyCompany/BioLORD-2023-C',
        *,
        hidden_dim: int = 768,
        num_modalities: int = 5,
    ):
        """Build the text tower, the projection head and the modality table.

        Args:
            biolord_checkpoint: Checkpoint identifier forwarded unchanged to
                ``Text_Tower``.
            hidden_dim: Input/output width of the projection head and width of
                each modality embedding. NOTE(review): this must equal the
                feature width produced by ``Text_Tower`` for the chosen
                checkpoint (previously hard-coded to 768) — confirm before
                changing the checkpoint or this value.
            num_modalities: Number of rows in the modality embedding table;
                valid modality ids are ``0 .. num_modalities - 1``.
        """
        super().__init__()
        self.text_tower = Text_Tower(biolord_checkpoint)
        # Projection head applied on top of the text-tower features.
        self.projection_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # One learned vector per modality id, added to the projected features.
        self.modality_embed = nn.Embedding(num_modalities, hidden_dim)

    def forward(self, text, modality, device):
        """Return projected text features shifted by their modality embedding.

        Args:
            text: Input forwarded to ``self.text_tower`` together with
                ``device`` (presumably a batch of strings — verify against
                ``Text_Tower``).
            modality: Integer index tensor of modality ids, as required by
                ``nn.Embedding``. It must live on the same device as this
                module's parameters; it is not moved here.
            device: Forwarded to ``self.text_tower``; not used otherwise.

        Returns:
            ``projection_layer(text_tower(text, device)) + modality_embed(modality)``.
            The two terms must broadcast. Presumably both are
            ``(batch, hidden_dim)`` — TODO confirm the text-tower output shape.
        """
        text_feature = self.text_tower(text, device)
        proj_text_feature = self.projection_layer(text_feature)
        modality_feature = self.modality_embed(modality)
        return proj_text_feature + modality_feature