saracandu committed on
Commit
a9d64ee
·
verified ·
1 Parent(s): 142a1a8

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -625
utils.py DELETED
@@ -1,625 +0,0 @@
1
- import ast
2
- import copy
3
- import math
4
- from typing import List, Optional, Tuple, Union
5
-
6
- import numpy as np
7
- import pandas as pd
8
- import torch
9
- import torch.utils.checkpoint
10
- from torch import nn
11
- import torch.nn.functional as F
12
- from torch.utils.data import Dataset
13
-
14
- from transformers.modeling_utils import PreTrainedModel
15
- from configuration import STLConfig
16
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
17
-
18
- import copy
19
- import pickle
20
- import os
21
- from collections import deque
22
-
23
- from stl import *
24
-
25
- from nltk.translate.bleu_score import sentence_bleu
26
- from handcoded_tokenizer import STLTokenizer
27
-
28
- import networkx as nx
29
- import phis_generator
30
-
31
- from datasets import load_dataset
32
-
33
- ############################################################################################################################
34
-
35
def load_pickle(path):
    """Read the file at ``path`` in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserialising; only call this on files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
39
-
40
-
41
def dump_pickle(name, thing):
    """Pickle ``thing`` into the file ``name + '.pickle'``.

    The ``.pickle`` suffix is always appended, so ``name`` should be given
    without an extension. Returns ``None``.
    """
    target = name + '.pickle'
    with open(target, 'wb') as handle:
        pickle.dump(thing, handle)
44
-
45
-
46
def set_time_thresholds(st):
    """Extract the temporal-interval settings from an operator token.

    Args:
        st: operator token, either bare (e.g. ``'eventually'``) or followed by
            an interval such as ``'eventually[3,10]'`` or ``'always[2,inf]'``.

    Returns:
        Tuple ``(unbound, right_unbound, left_time_bound, right_time_bound)``.
        A token that does not end in ``']'`` yields ``(True, False, 0, 0)``.
        Otherwise ``unbound`` is ``False``, the left bound is the first number
        in the brackets, and the right bound is the second number minus one,
        unless it is ``'inf'`` (then ``right_unbound`` is ``True`` and the
        right bound stays 0).
    """
    if st[-1] != ']':
        # No interval attached to the operator.
        return True, False, 0, 0

    bounds = st[st.index('[') + 1:-1].split(",")
    left_time_bound = int(bounds[0])
    right_unbound = bounds[1] == 'inf'
    right_time_bound = 0 if right_unbound else int(bounds[1]) - 1
    return False, right_unbound, left_time_bound, right_time_bound
58
-
59
-
60
def from_string_to_formula(st):
    """Parse a whitespace-tokenised STL formula string into a formula object.

    The string is split on whitespace, so parentheses, operators and atom
    parts must be separate tokens. Three shapes are recognised:

    * atom: ``<var> <cmp> <threshold>`` where ``<var>`` starts with ``x``;
      ``lte`` is ``True`` only when ``<cmp>`` is ``'<='``;
    * unary: ``<op> ( <child> )`` with ``<op>`` equal to ``not`` or starting
      with ``eventually`` / ``always`` (optionally followed by an interval);
    * binary: ``( <left> <op> <right> )`` with ``<op>`` equal to ``and`` /
      ``or``; any other operator is built as ``Until``.

    Args:
        st: the formula string.

    Returns:
        The root node (``Atom``, ``Not``, ``Eventually``, ``Globally``,
        ``And``, ``Or`` or ``Until``) with children parsed recursively.

    Raises:
        AssertionError: if an operator token is not one of the expected ones
            or the parenthesis structure does not give one or two children.
    """
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    unary_prefixes = ['no', 'ev', 'al']  # first two letters of not / eventually / always

    if root_arity <= 1:
        root_op_str = st_split[0]
        if root_op_str.startswith('x'):
            # The index follows a two-character prefix (presumably "x_").
            # Slice to the end of the token so that multi-digit indices such
            # as "x_10" are read in full rather than truncated to one digit.
            return Atom(var_index=int(root_op_str[2:]), lte=(st_split[1] == '<='),
                        threshold=float(st_split[2]))

        assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                or root_op_str.startswith('always'))
        # Drop the operator token and the parentheses wrapping its operand.
        current_st = st_split[2:-1]
        if root_op_str == 'not':
            root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
        elif root_op_str.startswith('eventually'):
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                  right_unbound=right_unbound, left_time_bound=left_time_bound,
                                  right_time_bound=right_time_bound)
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                right_unbound=right_unbound, left_time_bound=left_time_bound,
                                right_time_bound=right_time_bound)
        return root_phi

    # Binary root: strip the outermost parentheses and locate the two operands.
    current_st = st_split[1:-1]
    if '(' in current_st:
        # Pair every '(' with its matching ')'.
        par_queue = deque()
        par_idx_list = []
        for i, sub in enumerate(current_st):
            if sub == '(':
                par_queue.append(i)
            elif sub == ')':
                par_idx_list.append(tuple([par_queue.pop(), i]))
        # Merge nested / adjacent parenthesis ranges: each merged range is one
        # parenthesised operand of the root operator.
        children_range = []
        for begin, end in sorted(par_idx_list):
            if children_range and children_range[-1][1] >= begin - 1:
                children_range[-1][1] = max(children_range[-1][1], end)
            else:
                children_range.append([begin, end])
        n_children = len(children_range)
        assert (n_children in [1, 2])
        if n_children == 1:
            # Exactly one operand is parenthesised, the other is a bare atom.
            var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
            # Include the unary operator token that precedes the parenthesis.
            if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in unary_prefixes:
                children_range[0][0] -= 1
            parenthesised = current_st[children_range[0][0]:children_range[0][1] + 1]
            left_child_str = current_st[:3] if var_child_idx == 0 else parenthesised
            right_child_str = current_st[-3:] if var_child_idx == 1 else parenthesised
            root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                current_st[children_range[0][0] - 1]
            assert (root_op_str[:2] in ['an', 'or', 'un'])
        else:
            if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in unary_prefixes:
                children_range[0][0] -= 1
            if current_st[children_range[1][0] - 1][0:2] in unary_prefixes:
                children_range[1][0] -= 1
            # With two parenthesised operands, the token right after the first one is the root.
            root_op_str = current_st[children_range[0][1] + 1]
            assert (root_op_str[:2] in ['an', 'or', 'un'])
            left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
            right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
    else:
        # No parentheses: both operands are three-token atoms around the operator.
        left_child_str = current_st[:3]
        right_child_str = current_st[-3:]
        root_op_str = current_st[3]

    left_child_str = ' '.join(left_child_str)
    right_child_str = ' '.join(right_child_str)
    if root_op_str == 'and':
        root_phi = And(left_child=from_string_to_formula(left_child_str),
                       right_child=from_string_to_formula(right_child_str))
    elif root_op_str == 'or':
        root_phi = Or(left_child=from_string_to_formula(left_child_str),
                      right_child=from_string_to_formula(right_child_str))
    else:
        unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
        root_phi = Until(left_child=from_string_to_formula(left_child_str),
                         right_child=from_string_to_formula(right_child_str),
                         unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                         right_time_bound=right_time_bound)
    return root_phi
148
-
149
-
150
def scale_trajectories(traj):
    """Min-max rescale ``traj`` linearly into the interval [-1, 1].

    The minimum and maximum are obtained by reducing the first two
    dimensions of ``traj`` (dim 0, then dim 0 of the result), and the
    resulting statistics are broadcast against ``traj``.

    Returns:
        A new tensor ``-1 + 2 * (traj - min) / (max - min)``.
    """
    lowest = traj.min(dim=0).values.min(dim=0).values
    highest = traj.max(dim=0).values.max(dim=0).values
    span = highest - lowest
    return -1 + 2 * (traj - lowest) / span
155
-
156
-
157
def standardize_trajectories(traj_data, n_var):
    """Standardise the first ``n_var`` slices ``traj_data[:, i, :]`` in place.

    For every ``i`` in ``range(n_var)`` the slice is shifted by its mean and
    divided by its standard deviation (``torch.std`` defaults).

    Returns:
        The same ``traj_data`` tensor object, modified in place.
    """
    for var in range(n_var):
        channel = traj_data[:, var, :]
        # Statistics are taken before the slice is overwritten.
        mu = torch.mean(channel)
        sigma = torch.std(channel)
        traj_data[:, var, :] = (channel - mu) / sigma
    return traj_data
165
-
166
- ############################################################################################################################
167
-
168
class STLSinusoidalPositionalEmbedding(nn.Embedding):
    """Embedding table holding fixed sinusoidal positional encodings."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        # `padding_idx` is accepted but not forwarded to nn.Embedding.
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Fill `out` in place with sinusoidal features and return it.

        Features are not interleaved: the sine components occupy the first
        half of each row and the cosine components the second half
        (`[sentinel:]`). The parameter is marked as not requiring gradients.
        """
        n_pos, dim = out.shape
        # One divisor per feature column; columns 2k and 2k+1 share a frequency.
        divisors = [np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        position_enc = np.array([[pos / div for div in divisors] for pos in range(n_pos)])
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = (dim // 2) + 1 if dim % 2 else dim // 2
        out[:, :sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """Return the encodings for `seq_len` positions starting at `past_key_values_length`.

        `input_ids_shape` is expected to be [bsz x seqlen]; only the sequence
        length is used.
        """
        _, seq_len = input_ids_shape[:2]
        first = past_key_values_length
        positions = torch.arange(first, first + seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
200
-
201
class STLAttention(nn.Module):
    """Multi-head scaled dot-product attention, as in 'Attention is all you need'.

    Supports self-attention, cross-attention (when `key_value_states` is
    given) and incremental decoding through a `(key, value)` cache.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
                 is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
        super().__init__()
        self.embed_dim = embed_dim  # total width, split evenly across the heads
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads) == self.embed_dim
        self.scaling = self.head_dim ** -0.5  # applied to the projected queries
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Key / query / value projections.
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Output projection applied to the concatenated heads.
        self.W_o = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        """Reshape (batch, seq, embed) into a contiguous (batch, heads, seq, head_dim)."""
        per_head = tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(self,
                hidden_states: torch.Tensor,
                key_value_states: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None,
                layer_head_mask: Optional[torch.Tensor] = None,
                output_attentions: bool = False
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over `hidden_states`.

        Args:
            hidden_states: (batch, tgt_len, embed_dim) inputs that provide the queries.
            key_value_states: if given, keys and values are projected from it
                (cross-attention) instead of from `hidden_states`.
            past_key_value: cached `(key, value)` pair from previous steps.
            attention_mask: added to the raw scores viewed as
                (batch, heads, tgt_len, src_len) before the softmax.
            layer_head_mask: accepted but currently not applied.
            output_attentions: accepted but currently ignored.

        Returns:
            `(attn_output, None, past_key_value)`; the attention weights slot is
            always `None`. When `is_decoder` is set, `past_key_value` is the
            `(key, value)` pair used in this call, otherwise the argument is
            returned unchanged.
        """
        batch_size, tgt_len, _ = hidden_states.size()

        # Queries are scaled right after projection.
        query = self.W_q(hidden_states) * self.scaling

        if key_value_states is not None:
            # Cross-attention: reuse the cache when its length matches the source.
            cache_matches = (past_key_value is not None
                             and past_key_value[0].shape[2] == key_value_states.shape[1])
            if cache_matches:
                key, value = past_key_value[0], past_key_value[1]
            else:
                key = self._shape(self.W_k(key_value_states), -1, batch_size)
                value = self._shape(self.W_v(key_value_states), -1, batch_size)
        else:
            # Self-attention: project the current states, then extend any cache.
            key = self._shape(self.W_k(hidden_states), -1, batch_size)
            value = self._shape(self.W_v(hidden_states), -1, batch_size)
            if past_key_value is not None:
                key = torch.cat([past_key_value[0], key], dim=2)
                value = torch.cat([past_key_value[1], value], dim=2)

        if self.is_decoder:
            past_key_value = (key, value)

        # Fold the heads into the batch dimension for the batched matmuls.
        flat_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query = self._shape(query, tgt_len, batch_size).view(*flat_shape)
        key = key.reshape(*flat_shape)
        value = value.reshape(*flat_shape)
        src_len = key.size(1)

        # softmax(Q K^T + mask) V, with the scaling already folded into Q.
        scores = torch.bmm(query, key.transpose(1, 2))
        if attention_mask is not None:
            scores = scores.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
            scores = scores.view(batch_size * self.num_heads, tgt_len, src_len)

        # Normalise over the key axis, then apply dropout (training mode only).
        scores = F.softmax(scores, dim=-1)
        probs = F.dropout(scores, p=self.dropout, training=self.training)
        context = torch.bmm(probs, value)

        # Restore (batch, tgt_len, embed_dim) and mix the heads.
        context = context.view(batch_size, self.num_heads, tgt_len, self.head_dim)
        context = context.transpose(1, 2).reshape(batch_size, tgt_len, self.embed_dim)
        attn_output = self.W_o(context)

        return attn_output, None, past_key_value
306
-
307
-
308
class DatasetProcessor:
    """Load a pickled pandas dataset and turn every row into model-ready tensors."""

    def __init__(self, dataset_name, split="train", device="cuda" if torch.cuda.is_available() else "cpu"):
        # `split` is accepted but not used by this class.
        self.device = device
        # `dataset_name` is the path of the pickle file read by pandas.
        self.original_dataset = pd.read_pickle(dataset_name)
        self.processed_dataset = self._create_processed_dataset()

    def _create_processed_dataset(self):
        """Apply the row transformation to the whole dataset.

        Returns:
            The result of ``DataFrame.apply(..., axis=1)``: one dict per row
            with keys ``input_ids``, ``labels``, ``attention_mask`` and
            ``encoder_hidden_states``, each a tensor on ``self.device``.
        """
        def to_device(values, dtype):
            return torch.tensor(values, dtype=dtype).to(self.device)

        def transform_entry(entry):
            # Precomputed formula embedding, read from the 'Embedding512' column.
            encoder_hidden_states = to_device(entry['Embedding512'], torch.float32)

            # Shift the token sequence by one position to get inputs and targets.
            tokens = entry['Encoded_Formula']
            shifted_in = tokens[:-1]   # all tokens except the last
            shifted_out = tokens[1:]   # all tokens except the first
            # Token id 1 is masked out (0); every other token is visible (1).
            visible = [0 if token == 1 else 1 for token in shifted_in]

            return {
                'input_ids': to_device(shifted_in, torch.long),
                'labels': to_device(shifted_out, torch.long),
                'attention_mask': to_device(visible, torch.long),
                'encoder_hidden_states': encoder_hidden_states
            }

        return self.original_dataset.apply(transform_entry, axis=1)

    def get_processed_dataset(self):
        """Return the processed dataset built at construction time."""
        return self.processed_dataset
347
-
348
-
349
- # Create a `CustomDataset` class to properly format input data with respect to
350
- # the `input_ids`, `labels`, and `attention_mask` attributes for model training.
351
class CustomDataset(Dataset):
    """Torch dataset exposing `input_ids`, `labels`, `attention_mask` and
    `encoder_hidden_states` tensors built from string-encoded columns."""

    def __init__(self, df, device='cpu'):
        """
        Parse every sample of `df['train']` and store the resulting tensors.

        Args:
        - df: A mapping whose `'train'` entry supports `len()` and
          `['Encoded_Formula'][idx]` / `['Embedding'][idx]` access; both columns
          hold string representations of Python lists.
        - device: The device ('cpu' or 'cuda') where the tensors will be moved for processing.
        """
        self.device = device
        source = df['train']

        formulae_embeddings = []
        input_ids = []
        labels = []
        attention_masks = []

        for idx in range(len(source)):
            # Both columns store list literals as text; parse them safely.
            encoded_formula = ast.literal_eval(source['Encoded_Formula'][idx])
            formula_embedding = ast.literal_eval(source['Embedding'][idx].strip())

            formulae_embeddings.append(
                torch.tensor(formula_embedding, dtype=torch.float32).to(self.device))

            # Inputs drop the last token, labels drop the first (shift by one).
            shifted_in = encoded_formula[:-1]
            input_ids.append(torch.tensor(shifted_in, dtype=torch.long).to(self.device))
            labels.append(torch.tensor(encoded_formula[1:], dtype=torch.long).to(self.device))

            # Token id 1 is masked out (0); every other input token is visible (1).
            visible = [0 if token == 1 else 1 for token in shifted_in]
            attention_masks.append(torch.tensor(visible, dtype=torch.long).to(self.device))

        # Column-oriented storage: one list of tensors per field.
        self.df = {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_masks,
            'encoder_hidden_states': formulae_embeddings
        }

    def __len__(self):
        """
        Returns the number of samples in the dataset.

        `self.df` is a dict with one entry per field, so its own length is the
        number of fields (always 4); the sample count is the length of any of
        its lists.
        """
        return len(self.df['input_ids'])

    def __getitem__(self, idx):
        """
        Retrieves the dataset item at the given index.

        Args:
        - idx: The index of the sample to retrieve.

        Returns:
        - A dictionary containing the input data for the model.
        """
        return {field: values[idx] for field, values in self.df.items()}
434
-
435
- ############################################################################################################################
436
-
437
- # METRICS
438
-
439
def token_division(input_string):
    """Tokenize ``input_string`` and drop every ``"pad"`` token.

    A new ``STLTokenizer`` is built from ``tokenizer_files/tokenizer.json``
    on every call.
    """
    tokenizer = STLTokenizer('tokenizer_files/tokenizer.json')
    tokens = tokenizer.tokenize(input_string)
    return [token for token in tokens if token != "pad"]
442
-
443
-
444
-
445
def bleu_score(dataset):
    """Compute sentence-level BLEU between gold and generated formulae.

    Args:
        dataset: object supporting ``len()`` and ``[column][idx]`` access with
            the columns ``"Gold Formula"`` and ``"Generated Formula"``.

    Returns:
        Tuple ``(min, mean, max)`` of the per-sample BLEU scores.
    """
    bleu_scores = []

    for idx in range(len(dataset)):
        gold = token_division(dataset["Gold Formula"][idx])
        generated = token_division(dataset["Generated Formula"][idx])

        # `sentence_bleu` expects a *list of reference token lists*. Passing the
        # flat token list would treat every gold token as a separate reference,
        # so the single gold sequence is wrapped in a list.
        bleu_scores.append(sentence_bleu([gold], generated))

    return np.min(bleu_scores), np.mean(bleu_scores), np.max(bleu_scores)
456
-
457
-
458
-
459
def exact_match(dataset, gold_formula_column: str, generated_formula_column: str):
    """Average position-wise token agreement between gold and generated formulae.

    For every sample both formulae are tokenized; tokens are compared pairwise
    up to the shorter sequence and the number of equal pairs is divided by the
    number of gold tokens.

    Returns:
        The mean of the per-sample ratios.
    """
    ratios = []

    for idx in range(len(dataset)):
        gold = token_division(dataset[gold_formula_column][idx])
        generated = token_division(dataset[generated_formula_column][idx])

        hits = sum(1 for gold_token, gen_token in zip(gold, generated) if gold_token == gen_token)
        ratios.append(hits / len(gold))

    return np.mean(ratios)
476
-
477
-
478
-
479
def cosine_similarity(dataset):
    """Cosine similarity between gold and generated formula embeddings.

    The columns ``"Embedding Gold Formula"`` and
    ``"Embedding Generated Formula"`` hold string representations of numeric
    lists, which are parsed with ``ast.literal_eval``.

    Returns:
        Tuple ``(min, mean, max)`` of the per-sample similarities.
    """
    def parse(column, idx):
        return ast.literal_eval(dataset[column][idx])

    similarities = []
    for idx in range(len(dataset)):
        gold = parse("Embedding Gold Formula", idx)
        gen = parse("Embedding Generated Formula", idx)
        norms = np.linalg.norm(gold) * np.linalg.norm(gen)
        similarities.append(np.dot(gold, gen) / norms)

    return np.min(similarities), np.mean(similarities), np.max(similarities)
494
-
495
-
496
def euclidean_distance(dataset):
    """Distance (``torch.dist``) between gold and generated formula embeddings.

    The columns ``"Embedding Gold Formula"`` and
    ``"Embedding Generated Formula"`` hold string representations of numeric
    lists, which are parsed with ``ast.literal_eval`` and turned into tensors.

    Returns:
        Tuple ``(min, mean, max)`` of the per-sample distances.
    """
    def as_tensor(column, idx):
        return torch.tensor(ast.literal_eval(dataset[column][idx]))

    distances = [
        torch.dist(as_tensor("Embedding Gold Formula", idx),
                   as_tensor("Embedding Generated Formula", idx))
        for idx in range(len(dataset))
    ]

    return np.min(distances), np.mean(distances), np.max(distances)
508
-
509
-
510
- #######################################################################################################
511
-
512
def get_name_given_type(formula):
    """
    Returns the short label (a string) for the class of the formula's top node.

    Lookup is by exact type; a node of any other class raises ``KeyError``.
    """
    labels_by_class = {
        And: 'and',
        Or: 'or',
        Not: 'not',
        Eventually: 'F',
        Globally: 'G',
        Until: 'U',
        Atom: 'x',
    }
    node_class = type(formula)
    return labels_by_class[node_class]
519
-
520
-
521
def get_id(child_name, name, label_dict, idx):
    """
    Find a node name that is not yet a key of ``label_dict``.

    Starting from ``child_name``, the index is incremented and the candidate
    rebuilt as ``name(idx)`` until it is unused.

    Returns both the free name and the index it was built from (``idx`` is
    returned unchanged when ``child_name`` itself is free).
    """
    candidate = child_name
    while candidate in label_dict:
        idx += 1
        candidate = f"{name}({idx})"
    return candidate, idx
529
-
530
-
531
def get_temporal_list(temporal_node):
    """
    Returns the three-element feature vector of a temporal node.

    The first two slots hold the interval bounds as floats: the left bound is
    ``0.`` when the node is unbound, and the right bound is ``-1.`` when the
    node is unbound or right-unbound. The third slot is always ``0.``.
    """
    bounded = temporal_node.unbound is False
    left = float(temporal_node.left_time_bound) if bounded else 0.
    if bounded and temporal_node.right_unbound is False:
        right = float(temporal_node.right_time_bound)
    else:
        right = -1.
    return [left, right, 0.]
541
-
542
-
543
def add_internal_child(current_child, current_idx, label_dict):
    """
    Build a unique name for a child node.

    The candidate is ``<type label>(<current_idx>)``; ``get_id`` advances the
    index until the name is not already in ``label_dict``.

    Returns the name and the index it finally uses.
    """
    kind = get_name_given_type(current_child)
    candidate = kind + '(' + str(current_idx) + ')'
    return get_id(candidate, kind, label_dict, current_idx)
547
-
548
-
549
def add_leaf_child(node, name, label_dict, idx):
    """
    Register a leaf (atom) node and the variable node it points to.

    Args:
        node: the atom; its string form must start with the variable token
            and it must expose a numeric ``threshold``.
        name: name already chosen for the atom node.
        label_dict: mapping node name -> feature vector, updated in place.
        idx: identifier used to build the variable node name.

    Returns:
        ``(new_edges, label_dict, next_idx)`` where ``new_edges`` is the single
        edge ``[name, variable_name]`` and ``next_idx`` is ``idx + 2``.
    """
    # Each occurrence gets its own variable node, e.g. x_1(5), x_1(8).
    atom_idx = str(node).split()[0] + '(' + str(idx) + ')'
    if atom_idx not in label_dict:
        label_dict[atom_idx] = [0., 0., 0.]

    # The atom's feature vector carries only the rounded threshold; the
    # comparison sign is not encoded (both branches of the former if/else
    # stored exactly the same vector).
    label_dict[name] = [0., 0., round(node.threshold, 4)]

    return [[name, atom_idx]], label_dict, idx + 2
571
-
572
-
573
def traverse_formula(formula, idx, label_dict):
    """
    Recursively collect the edges and node labels of a formula tree.

    Args:
        formula: root of the (sub-)formula to visit.
        idx: identifier used to name the root node.
        label_dict: mapping node name -> feature vector, updated in place.

    Returns:
        ``(edges, label_dict)`` where ``edges`` is a list of
        ``[parent_name, child_name]`` pairs. An atom root yields no edges.
    """
    edges = []
    node_kind = type(formula)
    if node_kind is Atom:
        return edges, label_dict

    current_name = get_name_given_type(formula) + '(' + str(idx) + ')'
    if node_kind in (And, Or, Not):
        label_dict[current_name] = [0., 0., 0.]  # temp_left, temp_right, threshold
    else:
        label_dict[current_name] = get_temporal_list(formula)

    if node_kind in (And, Or, Until):
        children = (formula.left_child, formula.right_child)
    else:
        # eventually, globally, not
        children = (formula.child,)

    # The running index is threaded through the children, left to right.
    current_idx = idx
    for child in children:
        child_name, current_idx = add_internal_child(child, current_idx + 1, label_dict)
        edges.append([current_name, child_name])
        if type(child) is Atom:
            leaf_edges, leaf_labels, current_idx = add_leaf_child(child, child_name, label_dict,
                                                                  current_idx + 1)
            edges += leaf_edges
            label_dict.update(leaf_labels)
        sub_edges, sub_labels = traverse_formula(child, current_idx, label_dict)
        edges += sub_edges
        label_dict.update(sub_labels)

    return edges, label_dict
614
-
615
-
616
def build_dag(formula):
    """
    Build the directed graph of a formula.

    Returns the ``networkx.DiGraph`` made from the traversal edges together
    with the node-label dictionary; asserts that the graph is acyclic.
    """
    edge_list, node_labels = traverse_formula(formula, 0, {})
    dag = nx.from_edgelist(edge_list, create_using=nx.DiGraph)
    assert (nx.is_directed_acyclic_graph(dag))
    return dag, node_labels
621
-
622
-
623
def get_depth(formula):
    """
    Returns the number of edges on the longest path of the formula's graph.
    """
    dag, _ = build_dag(formula)
    longest_path = nx.dag_longest_path(dag)
    return len(longest_path) - 1