u8sand committed · verified
Commit ccd7396 · 1 Parent(s): e92e958

Update gsfm.py

Files changed (1)
  1. gsfm.py +18 -14
gsfm.py CHANGED
@@ -7,6 +7,10 @@ from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download
 UNK_IDX, PAD_IDX = 0, 1
 special_symbols = ['<unk>', '<pad>']
 
+def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
+    *bs, _ = indices.shape
+    return torch.zeros((*bs, num_classes,), device=device, dtype=dtype).scatter(1, indices, 1)
+
 class Vocab:
     def __init__(self, vocab, default_index=0):
         self.vocab = vocab
@@ -91,21 +95,22 @@ class GSFM(
     PyTorchModelHubMixin,
     tags=["gene", "gene set", "bioinformatics"],
 ):
-    def __init__(self, vocab_size, d_model=256, depth=2):
+    def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
         super().__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.depth = depth
-        self.embedding = torch.nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
-        self.encoder = MLP(*[d_model**n for n in range(1, depth)], d_model)
-        self.decoder = MLP(d_model*2, *[d_model**n for n in range(2, depth)], vocab_size)
+        self.dropout = dropout
+        self.partition = partition
+        self.weighted_loss = weighted_loss
+        self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
+        self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
         self.save_hyperparameters()
 
     def encode(self, x):
-        x = emb = self.embedding(x)
-        x = enc = self.encoder(emb)
-        x = torch.cat([enc.mean(1), emb.mean(1)], -1)
-        return x
+        x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
+        x[:, PAD_IDX] = 0
+        return self.encoder(x)
 
     def forward(self, x):
         x = self.encode(x)
@@ -113,12 +118,11 @@ class GSFM(
         return x
 
     def training_step(self, batch, batch_idx):
-        x, y = batch
-        is_x = torch.isnan(y)
-        y = torch.where(is_x, 0, y)
-        pos_weight = torch.where(is_x, 0, 1)
-        y_ = self(x)
-        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+        x_idx = y_idx = batch
+        y_ = self(x_idx)
+        y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
+        y[:, PAD_IDX] = 0
+        criterion = torch.nn.BCEWithLogitsLoss()
         loss = criterion(y_, y)
         self.log('loss', loss, prog_bar=True)
         return loss
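
For context on the change: the commit drops the token-embedding encoder and instead feeds the MLP a multi-hot bag-of-genes vector built by the new multihot_tensor helper, with the PAD_IDX column zeroed before encoding and before building the BCE target. The snippet below is a minimal sketch of that helper's behaviour on a toy batch; the vocabulary size and index values are made up for illustration and are not from the repository.

import torch

UNK_IDX, PAD_IDX = 0, 1  # same constants as at the top of gsfm.py

def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    # Definition added in this commit: scatter 1s at each token index,
    # collapsing the sequence dimension into a fixed-size multi-hot vector.
    *bs, _ = indices.shape
    return torch.zeros((*bs, num_classes,), device=device, dtype=dtype).scatter(1, indices, 1)

# Toy batch of two padded index sequences over a hypothetical 6-symbol vocabulary.
x = torch.tensor([
    [4, 2, 5, PAD_IDX],
    [3, 3, PAD_IDX, PAD_IDX],
])
m = multihot_tensor(x, num_classes=6, dtype=torch.float)
m[:, PAD_IDX] = 0  # encode() and training_step() zero the padding column the same way
print(m)
# tensor([[0., 0., 1., 0., 1., 1.],
#         [0., 0., 0., 1., 0., 0.]])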