# File size: 916 Bytes
# 7f3dfd7

import torch.nn as nn

from .text_tower import Text_Tower

class Knowledge_Encoder(nn.Module):
    """Encode text with ``Text_Tower`` and add a learned per-modality embedding.

    The text feature produced by ``Text_Tower`` is passed through a two-layer
    MLP projection (Linear -> GELU -> Linear). A modality embedding, looked up
    from an ``nn.Embedding`` table, is then added to the projected feature.

    Args:
        biolord_checkpoint: Checkpoint identifier forwarded unchanged to
            ``Text_Tower``.
        embed_dim: Input/output width of the projection MLP and width of the
            modality embedding. It must equal the last dimension of the
            feature returned by ``Text_Tower`` (presumably 768 for the default
            checkpoint -- TODO confirm against ``Text_Tower``).
        num_modalities: Number of rows in the modality embedding table. Valid
            modality ids are ``0 .. num_modalities - 1``.
    """

    def __init__(self, biolord_checkpoint='FremyCompany/BioLORD-2023-C',
                 *, embed_dim=768, num_modalities=5):
        super().__init__()
        self.text_tower = Text_Tower(biolord_checkpoint)
        # Submodule attribute names are kept as-is so existing checkpoints
        # (state_dict keys) still load.
        self.projection_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.modality_embed = nn.Embedding(num_modalities, embed_dim)

    def forward(self, text, modality, device):
        """Return the projected text feature plus its modality embedding.

        Args:
            text: Input passed straight to ``Text_Tower`` (format defined by
                ``Text_Tower``, not visible here).
            modality: Integer index tensor of modality ids, as required by
                ``nn.Embedding``. It is NOT moved to ``device`` here, so it
                must already be on the same device as ``modality_embed``.
                Presumably shape ``(batch,)`` -- TODO confirm.
            device: Forwarded to ``Text_Tower`` only.

        Returns:
            ``projection_layer(text_feature) + modality_embed(modality)``.
            The two operands must be broadcast-compatible.
        """
        text_feature = self.text_tower(text, device)
        proj_text_feature = self.projection_layer(text_feature)
        modality_feature = self.modality_embed(modality)
        # The original also computed ``text_feature + modality_feature`` but
        # never returned or used it; that wasted work has been removed.
        return proj_text_feature + modality_feature