# Source: M2-Encoder-1B/vlmo/modules/heads.py
# Uploaded by malusama in commit ea0524d ("Upload safetensors export"); original file size 665 bytes.
import torch.nn as nn
class Pooler(nn.Module):
    """Summarize a sequence of hidden states via its first element.

    The element at index 0 of dimension 1 is passed through a square
    ``nn.Linear`` layer and then a ``tanh`` non-linearity.

    NOTE(review): ``hidden_states`` is presumably shaped
    (batch, seq_len, hidden_size), with index 0 being a CLS-style token.
    Confirm this against the callers.

    Args:
        hidden_size: Size of the last dimension, for both input and output.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Keep the attribute names ``dense`` and ``activation``. The
        # state_dict keys (``dense.weight`` / ``dense.bias``) must match
        # the saved checkpoint.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        Dimension 1 is removed by the indexing. The last dimension must equal
        ``hidden_size`` for the Linear layer to apply.
        """
        leading = hidden_states[:, 0]
        return self.activation(self.dense(leading))


class ITCHead(nn.Module):
    """Bias-free linear projection from ``hidden_size`` to ``out_size``.

    NOTE(review): the name suggests this is the projection head for an
    image-text contrastive (ITC) objective. Confirm against the training code.

    Args:
        hidden_size: Size of the input's last dimension.
        out_size: Size of the projected output's last dimension.
    """

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        # The attribute name ``fc`` is part of the checkpoint's state_dict
        # key (``fc.weight``). There is no bias parameter.
        self.fc = nn.Linear(in_features=hidden_size, out_features=out_size, bias=False)

    def forward(self, x):
        """Project ``x`` along its last dimension (``x @ fc.weight.T``)."""
        return self.fc(x)