# Medal-S-V1.0 / model / knowledge_encoder.py
# Uploaded by spc819 ("Upload 69 files"), commit 7f3dfd7 (verified).
import torch.nn as nn
from .text_tower import Text_Tower
class Knowledge_Encoder(nn.Module):
    """Encode text prompts and fuse them with a learned modality embedding.

    Text is embedded by ``Text_Tower``, passed through a two-layer MLP
    (Linear -> GELU -> Linear), and a per-modality embedding vector is added
    to the projected result.

    The submodule attribute names (``text_tower``, ``projection_layer``,
    ``modality_embed``) determine the ``state_dict`` keys; keep them stable so
    existing checkpoints still load.

    Args:
        biolord_checkpoint: Identifier forwarded unchanged to ``Text_Tower``
            (presumably a Hugging Face model id -- confirm in text_tower.py).
        embed_dim: Width of the text features, of the projection MLP and of
            the modality embedding. Must equal the output width of
            ``Text_Tower``; defaults to 768, the value previously hard-coded.
        num_modalities: Number of rows in the modality embedding table, so
            modality indices must lie in ``[0, num_modalities)``. Defaults to
            5, the value previously hard-coded.
    """

    def __init__(
        self,
        biolord_checkpoint='FremyCompany/BioLORD-2023-C',
        embed_dim=768,
        num_modalities=5,
    ):
        super().__init__()
        self.text_tower = Text_Tower(biolord_checkpoint)
        # Projection head applied to the text-tower output; it consumes and
        # produces ``embed_dim``-wide features.
        self.projection_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.modality_embed = nn.Embedding(num_modalities, embed_dim)

    def forward(self, text, modality, device):
        """Return projected text features plus the modality embedding.

        Args:
            text: Input forwarded unchanged (with ``device``) to
                ``self.text_tower``.
            modality: Integer index tensor looked up in ``self.modality_embed``.
                Its embedding must broadcast against the projected text
                features (presumably one index per text entry -- TODO confirm).
            device: Forwarded to ``self.text_tower``. ``modality`` is not moved
                here, so it must already be on the same device as the model.

        Returns:
            ``projection_layer(text_tower(text, device)) + modality_embed(modality)``
        """
        text_feature = self.text_tower(text, device)
        proj_text_feature = self.projection_layer(text_feature)
        # Only the projected branch is returned. The modality embedding is
        # added after projection, so it is not fed through the MLP.
        return proj_text_feature + self.modality_embed(modality)