KingTechnician committed on
Commit
05c8ce6
·
verified ·
1 Parent(s): b0e5df5

Upload modeling_mimi.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_mimi.py +73 -0
modeling_mimi.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al. style).

    Precomputes a (1, max_len, d_model) table of interleaved sin/cos
    waves and adds it to the input in ``forward``.
    """

    def __init__(self, d_model, max_len=15000):
        super().__init__()
        # positions: (max_len, 1); frequencies follow 1 / 10000^(2i / d_model)
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even channels
        table[:, 1::2] = torch.cos(positions * freqs)  # odd channels
        # Buffer (not a Parameter): moves with .to(device), no gradients,
        # saved in the state dict.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # x: (batch, seq, d_model). Sequences longer than the table are
        # silently truncated to max_len before the encoding is added.
        limit = self.pe.size(1)
        if x.size(1) > limit:
            x = x[:, :limit, :]
        return x + self.pe[:, :x.size(1), :]
23
+
24
class MemoryModule(nn.Module):
    """Soft-attention read over a bank of trainable memory slots.

    Inputs are projected into the memory's key space; a softmax over
    slot similarities yields attention weights, and the read vector is
    the corresponding convex combination of the slots.
    """

    def __init__(self, input_dim, mem_dim=64, num_slots=20):
        super().__init__()
        self.mem_dim = mem_dim
        self.num_slots = num_slots
        # Projects inputs to queries that live in the memory's key space.
        self.query_proj = nn.Linear(input_dim, mem_dim)
        # Trainable slot bank, re-initialized with Kaiming uniform.
        self.memory = nn.Parameter(torch.randn(num_slots, mem_dim))
        nn.init.kaiming_uniform_(self.memory)

    def forward(self, x):
        queries = self.query_proj(x)
        scores = queries @ self.memory.t()        # (..., num_slots)
        weights = F.softmax(scores, dim=-1)       # rows sum to 1
        read = weights @ self.memory              # (..., mem_dim)
        return read, weights
39
+
40
class MemoryTransformerDetector(nn.Module):
    """Multimodal detector: fuses RGB, optical-flow, and audio features,
    encodes them with a transformer, augments each frame with a memory
    read, and emits a per-frame sigmoid score.
    """

    def __init__(self, rgb_dim=384, flow_dim=1024, audio_dim=768,
                 d_model=256, nhead=4, num_layers=2):
        super().__init__()
        third = d_model // 3
        # Each modality is projected so the concatenation is exactly
        # d_model wide; audio absorbs the remainder when d_model is not
        # divisible by 3.
        self.rgb_proj = nn.Linear(rgb_dim, third)
        self.flow_proj = nn.Linear(flow_dim, third)
        self.audio_proj = nn.Linear(audio_dim, d_model - (2 * third))
        self.pos_encoder = PositionalEncoding(d_model, max_len=15000)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=512,
            dropout=0.3,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.memory = MemoryModule(input_dim=d_model, mem_dim=d_model, num_slots=50)
        # Classifier sees [encoded | memory read] per frame -> probability.
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, rgb, flow, audio):
        # Accept unbatched (seq, feat) inputs by adding a batch axis.
        if rgb.dim() == 2:
            rgb = rgb.unsqueeze(0)
            flow = flow.unsqueeze(0)
            audio = audio.unsqueeze(0)
        fused = torch.cat(
            (self.rgb_proj(rgb), self.flow_proj(flow), self.audio_proj(audio)),
            dim=2,
        )
        fused = self.pos_encoder(fused)
        encoded = self.transformer(fused)
        read, att_weights = self.memory(encoded)
        scores = self.classifier(torch.cat((encoded, read), dim=2))
        # scores: (batch, seq, 1) -> (batch, seq); att_weights over slots.
        return scores.squeeze(2), att_weights