HaidarJomaa commited on
Commit
2174f5d
·
verified ·
1 Parent(s): e526c52

Upload 5 files

Browse files
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SpaceTimeMiniLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 384,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1536,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 6,
18
+ "num_space": 4,
19
+ "num_time": 60,
20
+ "pad_token_id": 0,
21
+ "position_embedding_type": "absolute",
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.51.3",
24
+ "type_vocab_size": 2,
25
+ "use_cache": true,
26
+ "use_space_embedding": true,
27
+ "use_time_embedding": true,
28
+ "vocab_size": 30522
29
+ }
dataset_space_time.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import AutoTokenizer, DataCollatorForLanguageModeling
4
+
5
# Training hyper-parameters: inputs are capped at SEQ_LEN tokens and batched
# BATCH_SIZE pairs at a time.
SEQ_LEN, BATCH_SIZE = 128, 32
# Prefer GPU when available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Subword tokenizer matching the MiniLM backbone this model was built from.
# NOTE: downloads from the HF hub on first use.
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
9
+
10
# Categorical id tables used to index the model's space/time embedding
# tables. Months are enumerated from 2017-01 onward, but only the first 60
# (2017-01 .. 2021-12) receive an id — matching the checkpoint's num_time=60;
# the generated 2022 months are filtered out.
space_mapping = {'UK': 0, 'US': 1, 'AUS': 2, 'CAN': 3}
time_mapping = {
    f"{year}-{month:02d}": index
    for index, (year, month) in enumerate(
        (y, m) for y in range(2017, 2023) for m in range(1, 13)
    )
    if index < 60
}
18
+
19
# Collator that applies BERT-style masked-language-model masking (15% of
# tokens, applied on the fly) for the MLM pre-training objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)
22
+
23
class PairwiseSimilarityDataset(Dataset):
    """Sentence-pair dataset carrying space/time ids and a similarity label.

    Expects a dataframe with columns ``sent1``, ``sent2``, ``t1``, ``t2``,
    ``s1``, ``s2`` and ``similarity``; the time/space columns hold string
    keys that are converted to integer ids via the module-level
    ``time_mapping`` / ``space_mapping`` tables at fetch time.
    """

    def __init__(self, df):
        # Re-index so the integer positions used by __getitem__ are
        # contiguous regardless of the caller's index.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.loc[idx]
        item = {
            "sent1": record.sent1,
            "sent2": record.sent2,
        }
        # Raises KeyError for unmapped time/space keys — surfacing bad rows
        # early rather than silently mislabeling them.
        item["t1"] = time_mapping[record.t1]
        item["t2"] = time_mapping[record.t2]
        item["s1"] = space_mapping[record.s1]
        item["s2"] = space_mapping[record.s2]
        item["sim"] = record.similarity
        return item
40
+
41
def collate_fn(batch):
    """Tokenize all sentences of a batch jointly and stack the metadata.

    Returns ``(enc, B, s1, s2, t1, t2, sims)`` where the first ``B`` rows of
    ``enc`` correspond to ``sent1`` of each pair and the last ``B`` rows to
    ``sent2``; the id tensors are long, ``sims`` is float.
    """
    # Encode both sides together so they share one padded length.
    sentences = [item["sent1"] for item in batch]
    sentences += [item["sent2"] for item in batch]
    enc = tokenizer(
        sentences,
        padding="longest",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )

    def _ids(key):
        # Stack one per-example integer field into a long tensor.
        return torch.tensor([item[key] for item in batch], dtype=torch.long)

    sims = torch.tensor([item["sim"] for item in batch], dtype=torch.float)
    return enc, len(batch), _ids("s1"), _ids("s2"), _ids("t1"), _ids("t2"), sims
inference.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
# Region and month vocabularies used to index the model's space/time
# embedding tables. Months cover 2017-01 .. 2021-12 (60 ids, matching the
# checkpoint's num_time); the generated 2022 months are skipped.
space_mapping = {'UK': 0, 'US': 1, 'AUS': 2, 'CAN': 3}
time_mapping = {}
for _index, (_year, _month) in enumerate(
    (y, m) for y in range(2017, 2023) for m in range(1, 13)
):
    if _index < 60:
        time_mapping[f"{_year}-{_month:02d}"] = _index
12
+
13
def compute_similarity(
    sent1: str,
    sent2: str,
    time1: str,
    time2: str,
    space1: str,
    space2: str,
    model,
    tokenizer,
    device="cuda"
) -> float:
    """Cosine similarity of two sentences under their space/time contexts.

    Args:
        sent1, sent2: raw input sentences.
        time1, time2: month keys such as ``"2019-07"``; must be present in
            ``time_mapping`` (2017-01 .. 2021-12).
        space1, space2: region keys (``'UK'``, ``'US'``, ``'AUS'``, ``'CAN'``).
        model: a SpaceTimeMiniLM-like model exposing
            ``embed(input_ids, attention_mask, space_ids, time_ids)``.
            (The original annotation ``model: None`` was incorrect — it
            declared the type ``None`` — and has been removed.)
        tokenizer: HuggingFace tokenizer matching the model.
        device: preferred device string; silently falls back to CPU when
            CUDA is unavailable.

    Returns:
        Cosine similarity in [-1, 1] as a Python float.

    Raises:
        KeyError: if a time/space key is missing from its mapping table.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    def _encode(text):
        # Fixed-length padding keeps both sentences on identical shapes.
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        ).to(device)

    enc1 = _encode(sent1)
    enc2 = _encode(sent2)

    # Map the categorical keys to the integer ids the embedding tables
    # expect (fresh names — the original rebound its str parameters).
    s1 = torch.tensor([space_mapping[space1]], dtype=torch.long, device=device)
    s2 = torch.tensor([space_mapping[space2]], dtype=torch.long, device=device)
    t1 = torch.tensor([time_mapping[time1]], dtype=torch.long, device=device)
    t2 = torch.tensor([time_mapping[time2]], dtype=torch.long, device=device)

    with torch.no_grad():
        emb1 = model.embed(enc1["input_ids"], enc1["attention_mask"], s1, t1)
        emb2 = model.embed(enc2["input_ids"], enc2["attention_mask"], s2, t2)

    sim = F.cosine_similarity(emb1, emb2, dim=-1)
    return sim.item()
65
+
66
def embed_sentence(
    sent: str,
    time: str,
    space: str,
    model,
    tokenizer,
    device="cuda"
) -> torch.Tensor:
    """Return the space/time-conditioned embedding of a single sentence.

    Args:
        sent: raw input sentence.
        time: month key such as ``"2019-07"``; must exist in ``time_mapping``.
        space: region key (``'UK'``, ``'US'``, ``'AUS'``, ``'CAN'``).
        model: a SpaceTimeMiniLM-like model exposing
            ``embed(input_ids, attention_mask, space_ids, time_ids)``.
            (The original annotation ``model: None`` was incorrect — it
            declared the type ``None`` — and has been removed.)
        tokenizer: HuggingFace tokenizer matching the model.
        device: preferred device string; silently falls back to CPU when
            CUDA is unavailable.

    Returns:
        The model's pooled sentence embedding on ``device`` — presumably
        shape (1, hidden_size); confirm against the model's ``embed``.

    Raises:
        KeyError: if ``time`` or ``space`` is missing from its mapping table.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    enc = tokenizer(
        sent,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt"
    ).to(device)

    # Map categorical keys to integer ids (fresh names — the original
    # rebound its str parameters, shadowing them with ints).
    space_ids = torch.tensor([space_mapping[space]], dtype=torch.long,
                             device=device)
    time_ids = torch.tensor([time_mapping[time]], dtype=torch.long,
                            device=device)

    with torch.no_grad():
        emb = model.embed(enc["input_ids"], enc["attention_mask"],
                          space_ids, time_ids)
    return emb
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce31c3a4bd9044584c2a2c058a548ecc8dfebc3397a6ea3dd078ea347b99a6f8
3
+ size 145165416
modeling_custom_minilm.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
5
+ from transformers import PreTrainedModel
6
+ from transformers.models.bert.modeling_bert import BertSelfAttention
7
+
8
class SpaceEmbedding(nn.Module):
    """Learned lookup table mapping a region id to a dense vector.

    Defaults match the released checkpoint: 4 regions, 384-d vectors.
    """

    def __init__(self, num_embeddings=4, embedding_dim=384):
        super().__init__()
        # Plain trainable embedding; attribute name `embedding` is part of
        # the checkpoint's state_dict and must not change.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, x):
        # x: long tensor of region ids -> (*x.shape, embedding_dim).
        return self.embedding(x)
15
+
16
class TimeEmbedding(nn.Module):
    """Fixed sinusoidal embedding over month indices (non-trainable buffer).

    NOTE(review): the frequency exponent is ``2*i/dim`` with ``i`` already
    stepping by 2, i.e. ``10000**(2i/d)`` — this differs from the canonical
    transformer formula (``10000**(i/d)``). It is preserved exactly because
    the released checkpoint was trained with it.
    """

    def __init__(self, max_months, dim=384):
        super().__init__()
        self.dim = dim
        positions = torch.arange(0, max_months).unsqueeze(1)
        even_idx = torch.arange(0, dim, 2)
        denom = 10000 ** (2 * even_idx / dim)
        table = torch.zeros(max_months, dim)
        table[:, 0::2] = torch.sin(positions / denom)
        table[:, 1::2] = torch.cos(positions / denom)
        # Registered (persistent) buffer: moves with .to(device) and keeps
        # the state_dict key "pe" the checkpoint expects.
        self.register_buffer("pe", table)

    def forward(self, idx):
        # idx: long tensor of month ids -> (*idx.shape, dim).
        return self.pe[idx]
29
+
30
# ----------------------------
# 1) Custom Space–Time Attention
# ----------------------------
class SpaceTimeSelfAttention(nn.Module):
    """Self-attention that modulates raw query-key scores with pairwise
    time- and space-similarity terms.

    Wraps a pretrained ``BertSelfAttention`` (``self.orig``) and reuses its
    Q/K/V projections, head reshaping and attention dropout unchanged; adds
    two extra projections (``W_t``, ``W_s``) for the per-token time/space
    embeddings. Unlike HF's BertSelfAttention, ``forward`` returns a bare
    tensor (or a 2-tuple when ``output_attentions=True``), so it is called
    directly by SpaceTimeMiniLM rather than through the HF encoder.
    """

    def __init__(self, orig_self: BertSelfAttention, config):
        super().__init__()
        # Keep the pretrained attention module so its weights are reused.
        self.orig = orig_self
        self.config = config
        # Projections applied to time/space embeddings before similarity.
        self.W_t = nn.Linear(config.hidden_size, config.hidden_size)
        self.W_s = nn.Linear(config.hidden_size, config.hidden_size)

    def transpose_for_scores(self, x):
        # Delegate the (B, L, H) -> (B, heads, L, head_dim) reshape to the
        # wrapped implementation so head splitting stays consistent.
        return self.orig.transpose_for_scores(x)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        time_embeddings=None,
        space_embeddings=None,
    ):
        """Attend over ``hidden_states`` with space/time-modulated scores.

        ``time_embeddings`` / ``space_embeddings`` are required in practice
        (no None-guard below) and are expected to be per-token tensors of
        the same (B, L, H) shape as ``hidden_states``. ``attention_mask``
        is an additive extended mask (large negative at padded positions).
        """
        # Standard Q/K/V projections from the pretrained module.
        mixed_q = self.orig.query(hidden_states)
        mixed_k = self.orig.key(hidden_states)
        mixed_v = self.orig.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_q)
        key_layer = self.transpose_for_scores(mixed_k)
        value_layer = self.transpose_for_scores(mixed_v)

        # Project context embeddings and split into heads like Q/K/V.
        T = self.W_t(time_embeddings)
        S = self.W_s(space_embeddings)
        T_layer = self.transpose_for_scores(T)
        S_layer = self.transpose_for_scores(S)

        # Plain dot-product attention scores (scaling applied later).
        base_scores = torch.matmul(
            query_layer,
            key_layer.transpose(-1, -2)
        )

        eps = 1e-6
        # NOTE(review): these "similarity" terms divide by the norm of the
        # row vectors only (broadcast over the last dim), not by both norms
        # as cosine similarity would — asymmetric by construction. The
        # checkpoint was trained with this exact form; do not "fix" without
        # retraining.
        T_norm = T_layer.norm(dim=-1, keepdim=True)
        time_sim = torch.matmul(
            T_layer,
            T_layer.transpose(-1, -2)
        ) / (T_norm + eps)

        S_norm = S_layer.norm(dim=-1, keepdim=True)
        space_sim = torch.matmul(
            S_layer,
            S_layer.transpose(-1, -2)
        ) / (S_norm + eps)

        # Multiplicative gating of content scores by both context terms.
        attn_scores = base_scores * time_sim * space_sim

        # Scale by sqrt(head_dim) as in standard scaled dot-product attention.
        dk = self.config.hidden_size // self.config.num_attention_heads
        attn_scores = attn_scores / math.sqrt(dk)

        if attention_mask is not None:
            # Additive mask: padded keys get a large negative score.
            attn_scores = attn_scores + attention_mask
        attn_probs = nn.Softmax(dim=-1)(attn_scores)
        # Reuse the pretrained module's attention-probability dropout.
        attn_probs = self.orig.dropout(attn_probs)

        if head_mask is not None:
            attn_probs = attn_probs * head_mask

        # Weighted sum of values, then merge heads back to (B, L, H).
        context = torch.matmul(attn_probs, value_layer)
        context = context.permute(0, 2, 1, 3).contiguous()
        new_shape = context.size()[:-2] + (self.config.hidden_size,)
        context = context.view(*new_shape)

        if output_attentions:
            return (context, attn_probs)
        return context
106
+
107
+
108
# ----------------------------
# 2) Full Space–Time–MiniLM Model
# ----------------------------
class SpaceTimeMiniLM(PreTrainedModel):
    """MiniLM-style BERT encoder whose per-layer self-attention is replaced
    by SpaceTimeSelfAttention, conditioning every layer on space/time ids.

    Heads: ``mlm_head`` (token-level vocabulary logits), ``space_head`` and
    ``time_head`` (pooled-output classification over the space/time ids).
    ``embed`` produces a masked-mean sentence embedding for similarity use.
    """
    config_class = AutoConfig
    def __init__(self, config):
        super().__init__(config)
        # Randomly initialized backbone from the config; trained weights
        # arrive via the accompanying safetensors checkpoint.
        self.base = AutoModel.from_config(config)
        # NOTE(review): redundant — PreTrainedModel.__init__ already stores
        # config; kept as-is.
        self.config = config

        # Swap every layer's self-attention for the space/time variant,
        # wrapping (and reusing) the original Q/K/V weights.
        for layer in self.base.encoder.layer:
            orig_self = layer.attention.self
            layer.attention.self = SpaceTimeSelfAttention(orig_self, self.config)

        self.space_embed = SpaceEmbedding(num_embeddings=config.num_space,
                                          embedding_dim=self.config.hidden_size)
        self.time_embed = TimeEmbedding(max_months=config.num_time,
                                        dim=self.config.hidden_size)

        self.mlm_head = nn.Linear(self.config.hidden_size,
                                  config.vocab_size)
        self.space_head = nn.Linear(self.config.hidden_size, config.num_space)
        self.time_head = nn.Linear(self.config.hidden_size, config.num_time)

    def forward(self, input_ids, attention_mask, space_ids, time_ids):
        """Run the multi-task forward pass.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: (B, L) 1/0 padding mask.
            space_ids: (B,) region ids, < config.num_space.
            time_ids: (B,) month ids, < config.num_time.

        Returns:
            (mlm_logits, space_logits, time_logits):
            (B, L, vocab_size), (B, num_space), (B, num_time).
        """
        B, L = input_ids.size()

        # Convert the 1/0 mask to the additive extended form the attention
        # layers expect (large negative at padded positions).
        extended_mask = self.base.get_extended_attention_mask(attention_mask, (B, L), device=input_ids.device)

        # Backbone embedding layer, called with input_ids only — position
        # (and default token-type) embeddings are handled inside it.
        emb = self.base.embeddings(input_ids)

        # One context vector per example, broadcast to every token position.
        S = self.space_embed(space_ids)
        T = self.time_embed(time_ids)
        S = S.unsqueeze(1).expand(-1, L, -1)
        T = T.unsqueeze(1).expand(-1, L, -1)

        # Manual encoder loop: the custom attention needs the extra
        # time/space kwargs, which HF's stock encoder does not forward.
        hidden_states = emb
        for layer in self.base.encoder.layer:
            attn_out = layer.attention.self(
                hidden_states,
                attention_mask=extended_mask,
                head_mask=None,
                output_attentions=False,
                time_embeddings=T,
                space_embeddings=S
            )
            # Residual + layernorm, FFN, second residual — the standard
            # BERT layer plumbing, reused unchanged.
            attn_out = layer.attention.output(attn_out, hidden_states)
            interm = layer.intermediate(attn_out)
            hidden_states = layer.output(interm, attn_out)

        sequence_output = hidden_states
        # CLS-based pooler feeds the space/time classification heads.
        pooled_output = self.base.pooler(sequence_output)

        mlm_logits = self.mlm_head(sequence_output)
        space_logits = self.space_head(pooled_output)
        time_logits = self.time_head(pooled_output)

        return mlm_logits, space_logits, time_logits

    def embed(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        space_ids: torch.LongTensor,
        time_ids: torch.LongTensor
    ) -> torch.FloatTensor:
        """Return a (B, hidden_size) masked-mean sentence embedding.

        Same encoder pass as ``forward`` (duplicated here rather than
        shared), but pooled by attention-mask-weighted mean over tokens
        instead of the CLS pooler, and with no head applied.
        """
        B, L = input_ids.size()

        extended_mask = self.base.get_extended_attention_mask(
            attention_mask, (B, L), device=input_ids.device
        )

        hidden_states = self.base.embeddings(input_ids)

        S = self.space_embed(space_ids)
        T = self.time_embed(time_ids)
        S = S.unsqueeze(1).expand(-1, L, -1)
        T = T.unsqueeze(1).expand(-1, L, -1)

        for layer in self.base.encoder.layer:
            attn_out = layer.attention.self(
                hidden_states,
                attention_mask=extended_mask,
                head_mask=None,
                output_attentions=False,
                time_embeddings=T,
                space_embeddings=S
            )
            attn_out = layer.attention.output(attn_out, hidden_states)
            interm = layer.intermediate(attn_out)
            hidden_states = layer.output(interm, attn_out)

        # Mean over real (unmasked) tokens; clamp avoids division by zero
        # for an all-padding row.
        mask_exp = attention_mask.unsqueeze(-1).expand_as(hidden_states).float()
        sum_emb = torch.sum(hidden_states * mask_exp, dim=1)
        sum_mask = mask_exp.sum(dim=1).clamp(min=1e-9)
        pooled = sum_emb / sum_mask

        return pooled