# TaskCLIP / models/CoCoTask_Model.py
# HanningChen
# Initial HF Space: FastAPI + HTML (no weights yet)
# f2f112a
import torch
from .Transformer import TransformerDecoderLayer
from .Transformer import TransformerDecoder
from .ScoreFunction import ScoreFunction
class CoCoTask_Model(torch.nn.Module):
    """Transformer-decoder head that scores ``tgt`` rows against ``memory`` rows.

    ``tgt`` and ``memory`` are refined by a stack of ``TransformerDecoderLayer``
    modules. The refined embeddings are compared with a matrix product, and
    the resulting raw score matrix is mapped to final scores by
    ``ScoreFunction``.

    Args:
        num_layers: Number of decoder layers in the stack.
        norm: Unused. The decoder always receives ``self.decoder_norm``; the
            argument is kept only for signature compatibility.
        return_intermediate: Forwarded to ``TransformerDecoder``.
        d_model: Embedding width; also the size of the ``LayerNorm`` handed to
            the decoder.
        nhead: Number of attention heads, forwarded to each decoder layer.
        dim_feedforward: Feed-forward width, forwarded to each decoder layer.
        dropout: Dropout probability, forwarded to each decoder layer.
        N_words: Forwarded to ``ScoreFunction``.
        activation: Activation name, forwarded to each decoder layer.
        normalize_before: Forwarded to each decoder layer (presumably selects
            pre- vs post-norm; TODO confirm in ``Transformer.py``).
        device: Unused in this class; kept for signature compatibility.
    """

    def __init__(self,
                 num_layers=3,
                 norm=None,
                 return_intermediate=False,
                 d_model=1024,
                 nhead=8,
                 dim_feedforward=2048,
                 dropout=0.1,
                 N_words=10,
                 activation="relu",
                 normalize_before=False,
                 device="cuda:0") -> None:
        super().__init__()
        self.decoder_norm = torch.nn.LayerNorm(d_model)
        # Registered as its own submodule as well as passed to the decoder.
        # Removing it would change state_dict keys, so check saved
        # checkpoints before touching it.
        self.decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.decoder = TransformerDecoder(
            self.decoder_layer,
            num_layers,
            self.decoder_norm,
            return_intermediate=return_intermediate,
        )
        self.ScoreFunction = ScoreFunction(N_words=N_words)
        # Not read anywhere in this class; presumably consumed by callers
        # when thresholding scores (verify at call sites).
        self.threshold = 0.35

    def forward(self, tgt, memory, norm=False, softmax_dim=-1):
        """Refine ``tgt``/``memory`` with the decoder and score every pair.

        Args:
            tgt: 2-D tensor (``torch.mm`` below requires 2-D inputs);
                presumably ``(n_tgt, d_model)``. TODO confirm with callers.
            memory: 2-D tensor, presumably ``(n_mem, d_model)``.
            norm: If True, apply a softmax to the raw similarity matrix
                before ``ScoreFunction``.
            softmax_dim: Axis of the ``(n_tgt, n_mem)`` raw score matrix that
                is normalized when ``norm`` is True. It defaults to the last
                axis, i.e. over ``memory`` rows. Which axis is correct is still
                an open TODO.

        Returns:
            Tuple ``(tgt_new, memory_new, score_res)``. The first two are the
            decoder outputs; ``score_res`` is ``ScoreFunction`` applied to the
            (optionally softmaxed) score matrix.
        """
        # The third argument is passed as None; presumably an attention mask.
        # Confirm in TransformerDecoder.
        tgt_new, memory_new = self.decoder(tgt, memory, None)
        # Pairwise dot products: entry (i, j) = <tgt_new[i], memory_new[j]>.
        score_raw = torch.mm(tgt_new, memory_new.T)
        # TODO: check whether ImageBind features need this softmax at all.
        if norm:
            # torch.softmax has no default dim. Calling it without one (as
            # the previous code did) raises TypeError.
            score_raw = torch.softmax(score_raw, dim=softmax_dim)
        score_res = self.ScoreFunction(score_raw)
        return tgt_new, memory_new, score_res