Subh775 commited on
Commit
504c8fa
·
verified ·
1 Parent(s): 9460f33

Add region_model.py for self-contained custom code

Browse files
Files changed (1) hide show
  1. region_model.py +43 -0
region_model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .fourier_features import FourierFeatures
4
+
5
+ class RegionModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ self.position_features = FourierFeatures(2, 256)
10
+ self.position_encoder = nn.Linear(256, 2048)
11
+ self.size_features = FourierFeatures(2, 256)
12
+ self.size_encoder = nn.Linear(256, 2048)
13
+
14
+ self.position_decoder = nn.Linear(2048, 2)
15
+ self.size_decoder = nn.Linear(2048, 2)
16
+ self.confidence_decoder = nn.Linear(2048, 1)
17
+
18
+ def encode_position(self, position):
19
+ return self.position_encoder(self.position_features(position))
20
+
21
+ def encode_size(self, size):
22
+ return self.size_encoder(self.size_features(size))
23
+
24
+ def decode_position(self, x):
25
+ return self.position_decoder(x)
26
+
27
+ def decode_size(self, x):
28
+ return self.size_decoder(x)
29
+
30
+ def decode_confidence(self, x):
31
+ return self.confidence_decoder(x)
32
+
33
+ def encode(self, position, size):
34
+ return torch.stack(
35
+ [self.encode_position(position), self.encode_size(size)], dim=0
36
+ )
37
+
38
+ def decode(self, position_logits, size_logits):
39
+ return (
40
+ self.decode_position(position_logits),
41
+ self.decode_size(size_logits),
42
+ self.decode_confidence(size_logits),
43
+ )