basilboy committed on
Commit
6d761ae
·
verified ·
1 Parent(s): db8dda6

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +131 -0
model.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model.py - Simple transformer model for microbiome data
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from typing import Dict
8
+
9
+
10
+
11
+ class MicrobiomeTransformer(nn.Module):
12
+ """
13
+ Simple transformer model for microbiome OTU embeddings
14
+ Handles two types of embeddings with separate input projections
15
+ Returns per-embedding predictions with variable length output
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ input_dim_type1: int = 384,
21
+ input_dim_type2: int = 1536,
22
+ d_model: int = 512,
23
+ nhead: int = 8,
24
+ num_layers: int = 6,
25
+ dim_feedforward: int = 2048,
26
+ dropout: float = 0.1,
27
+ use_output_activation: bool = True
28
+ ):
29
+ super().__init__()
30
+
31
+ # Store activation flag
32
+ self.use_output_activation = use_output_activation
33
+
34
+ # Separate input projections for each embedding type
35
+ self.input_projection_type1 = nn.Linear(input_dim_type1, d_model)
36
+ self.input_projection_type2 = nn.Linear(input_dim_type2, d_model)
37
+
38
+ # Transformer encoder
39
+ encoder_layer = nn.TransformerEncoderLayer(
40
+ d_model=d_model,
41
+ nhead=nhead,
42
+ dim_feedforward=dim_feedforward,
43
+ dropout=dropout,
44
+ batch_first=True
45
+ )
46
+
47
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
48
+
49
+ # Output layers - per position
50
+ self.output_projection = nn.Linear(d_model, 1)
51
+
52
+
53
+ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
54
+ """
55
+ Args:
56
+ batch: Dict with:
57
+ - 'embeddings_type1': (batch_size, seq_len1, input_dim_type1)
58
+ - 'embeddings_type2': (batch_size, seq_len2, input_dim_type2)
59
+ - 'mask': (batch_size, seq_len1 + seq_len2) - combined mask
60
+ - 'type_indicators': (batch_size, seq_len1 + seq_len2) - which type each position is
61
+
62
+ Returns:
63
+ torch.Tensor: (batch_size, seq_len1 + seq_len2) - value per embedding position
64
+ """
65
+ embeddings_type1 = batch['embeddings_type1'] # (batch_size, seq_len1, input_dim_type1)
66
+ embeddings_type2 = batch['embeddings_type2'] # (batch_size, seq_len2, input_dim_type2)
67
+ mask = batch['mask'] # (batch_size, total_seq_len)
68
+ type_indicators = batch['type_indicators'] # (batch_size, total_seq_len) - 0 for type1, 1 for type2
69
+
70
+ # Project each type separately
71
+ x1 = self.input_projection_type1(embeddings_type1) # (batch_size, seq_len1, d_model)
72
+ x2 = self.input_projection_type2(embeddings_type2) # (batch_size, seq_len2, d_model)
73
+
74
+ # Concatenate along sequence dimension
75
+ x = torch.cat([x1, x2], dim=1) # (batch_size, total_seq_len, d_model)
76
+
77
+ # Transformer (mask padded tokens)
78
+ x = self.transformer(x, src_key_padding_mask=~mask) # (batch_size, total_seq_len, d_model)
79
+
80
+ # Output projection per position
81
+ output = self.output_projection(x) # (batch_size, total_seq_len, 1)
82
+
83
+
84
+ output = output.squeeze(-1) # (batch_size, total_seq_len)
85
+
86
+ # Mask out padded positions
87
+ output = output * mask.float()
88
+
89
+ return output
90
+
91
+
92
# Example usage
if __name__ == "__main__":
    # Build a small demo model; the type-2 dimension deliberately differs
    # from the class default to show the projections adapt.
    demo_model = MicrobiomeTransformer(
        input_dim_type1=384,
        input_dim_type2=256,
        d_model=512,
        nhead=8,
        num_layers=6
    )

    # Dummy batch dimensions.
    n_samples = 4
    n_type1 = 60  # Type 1 embeddings
    n_type2 = 40  # Type 2 embeddings
    n_positions = n_type1 + n_type2

    # Indicator row per sample: zeros for type-1 positions, ones for type-2.
    indicators = torch.cat(
        [
            torch.zeros(n_samples, n_type1, dtype=torch.long),  # Type 1
            torch.ones(n_samples, n_type2, dtype=torch.long),   # Type 2
        ],
        dim=1,
    )

    demo_batch = {
        'embeddings_type1': torch.randn(n_samples, n_type1, 384),
        'embeddings_type2': torch.randn(n_samples, n_type2, 256),
        'mask': torch.ones(n_samples, n_positions, dtype=torch.bool),
        'type_indicators': indicators,
    }

    # Mark the tail of every sequence as padding.
    demo_batch['mask'][:, 80:] = False

    preds = demo_model(demo_batch)
    print(f"Output shape: {preds.shape}")  # Should be (4, 100)
    print(f"Type 1 output shape: {preds[:, :n_type1].shape}")  # (4, 60)
    print(f"Type 2 output shape: {preds[:, n_type1:n_type1 + n_type2].shape}")  # (4, 40)

    # Padded positions should have been zeroed by the forward pass.
    print(f"Padded positions sum: {preds[:, 80:].sum().item()}")  # Should be 0

    # Inspect the unpadded outputs.
    live = preds[:, :80]
    print(f"Active output range: {live.min().item():.3f} to {live.max().item():.3f}")