| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .text_encoder import Text_Encoder |
| from .resunet_film import UNetRes_FiLM |
|
|
class LASSNet(nn.Module):
    """Text-conditioned mask estimator.

    Combines a ``Text_Encoder`` (which tokenizes and embeds a caption) with a
    ``UNetRes_FiLM`` network that is conditioned on that embedding. The
    U-Net output is passed through a sigmoid, so every element of the
    returned mask lies in (0, 1).

    NOTE(review): "LASS" presumably stands for language-queried audio source
    separation — confirm with the project docs.
    """

    def __init__(self, device='cuda'):
        """Construct the text encoder and the conditioned U-Net.

        Args:
            device: Passed unchanged to ``Text_Encoder``; this class does not
                otherwise move tensors or submodules to it.
        """
        super().__init__()
        self.text_embedder = Text_Encoder(device)
        # NOTE(review): cond_embedding_dim=256 presumably has to equal the
        # width of the text embedding produced by Text_Encoder — confirm.
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        """Return a sigmoid mask for ``x`` conditioned on ``caption``.

        Args:
            x: Input forwarded as the first argument to the U-Net
                (shape/dtype not established here — TODO confirm).
            caption: Whatever ``Text_Encoder.tokenize`` accepts
                (presumably a string or list of strings — verify).

        Returns:
            ``torch.sigmoid`` of the U-Net output.
        """
        token_ids, attention_mask = self.text_embedder.tokenize(caption)
        # The encoder returns a sequence-like result; only its first element
        # is used as the conditioning embedding.
        embedding = self.text_embedder(token_ids, attention_mask)[0]
        # The very same embedding object is supplied for both conditioning
        # inputs of the U-Net (the second one is its decoder conditioning).
        logits = self.UNet(x, embedding, embedding)
        return torch.sigmoid(logits)

    def get_tokenizer(self):
        """Return the tokenizer object owned by the text encoder."""
        return self.text_embedder.tokenizer
|
|