Sadjad Alikhani
commited on
Update input_preprocess.py
Browse files- input_preprocess.py +2 -22
input_preprocess.py
CHANGED
|
@@ -16,19 +16,6 @@ import pickle
|
|
| 16 |
import DeepMIMOv3
|
| 17 |
import torch
|
| 18 |
|
| 19 |
-
def set_random_seed(seed=42):
|
| 20 |
-
torch.manual_seed(seed)
|
| 21 |
-
np.random.seed(seed)
|
| 22 |
-
#random.seed(seed)
|
| 23 |
-
if torch.cuda.is_available():
|
| 24 |
-
torch.cuda.manual_seed_all(seed)
|
| 25 |
-
# Ensures deterministic behavior
|
| 26 |
-
torch.backends.cudnn.deterministic = True
|
| 27 |
-
torch.backends.cudnn.benchmark = False
|
| 28 |
-
|
| 29 |
-
# Apply random seed
|
| 30 |
-
set_random_seed()
|
| 31 |
-
|
| 32 |
#%% Scenarios List
|
| 33 |
def scenarios_list():
|
| 34 |
"""Returns an array of available scenarios."""
|
|
@@ -208,7 +195,6 @@ def get_parameters(scenario):
|
|
| 208 |
|
| 209 |
return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
|
| 210 |
|
| 211 |
-
|
| 212 |
#%% Sample Generation
|
| 213 |
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
|
| 214 |
"""
|
|
@@ -226,7 +212,6 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
| 226 |
Returns:
|
| 227 |
sample (list): Generated sample for the user.
|
| 228 |
"""
|
| 229 |
-
set_random_seed()
|
| 230 |
|
| 231 |
tokens = patch[user_idx]
|
| 232 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
|
@@ -246,8 +231,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
| 246 |
input_ids[pos] = np.random.rand(patch_size)
|
| 247 |
elif rnd_num < 0.9:
|
| 248 |
input_ids[pos] = word2id['[MASK]']
|
| 249 |
-
|
| 250 |
-
# print(f'masked_pos: {masked_pos}')
|
| 251 |
return [input_ids, masked_tokens, masked_pos]
|
| 252 |
|
| 253 |
|
|
@@ -323,8 +307,7 @@ def load_var(path):
|
|
| 323 |
|
| 324 |
return var
|
| 325 |
|
| 326 |
-
#%%
|
| 327 |
-
|
| 328 |
def label_gen(task, data, scenario, n_beams=64):
|
| 329 |
|
| 330 |
idxs = np.where(data['user']['LoS'] != -1)[0]
|
|
@@ -364,13 +347,10 @@ def label_gen(task, data, scenario, n_beams=64):
|
|
| 364 |
return label.astype(int)
|
| 365 |
|
| 366 |
def steering_vec(array, phi=0, theta=0, kd=np.pi):
|
| 367 |
-
# phi = azimuth
|
| 368 |
-
# theta = elevation
|
| 369 |
idxs = DeepMIMOv3.ant_indices(array)
|
| 370 |
resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
|
| 371 |
return resp / np.linalg.norm(resp)
|
| 372 |
|
| 373 |
-
|
| 374 |
def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
|
| 375 |
labels = []
|
| 376 |
for scenario_idx in scenario_idxs:
|
|
|
|
| 16 |
import DeepMIMOv3
|
| 17 |
import torch
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
#%% Scenarios List
|
| 20 |
def scenarios_list():
|
| 21 |
"""Returns an array of available scenarios."""
|
|
|
|
| 195 |
|
| 196 |
return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
|
| 197 |
|
|
|
|
| 198 |
#%% Sample Generation
|
| 199 |
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
|
| 200 |
"""
|
|
|
|
| 212 |
Returns:
|
| 213 |
sample (list): Generated sample for the user.
|
| 214 |
"""
|
|
|
|
| 215 |
|
| 216 |
tokens = patch[user_idx]
|
| 217 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
|
|
|
| 231 |
input_ids[pos] = np.random.rand(patch_size)
|
| 232 |
elif rnd_num < 0.9:
|
| 233 |
input_ids[pos] = word2id['[MASK]']
|
| 234 |
+
|
|
|
|
| 235 |
return [input_ids, masked_tokens, masked_pos]
|
| 236 |
|
| 237 |
|
|
|
|
| 307 |
|
| 308 |
return var
|
| 309 |
|
| 310 |
+
#%% Label Generation
|
|
|
|
| 311 |
def label_gen(task, data, scenario, n_beams=64):
|
| 312 |
|
| 313 |
idxs = np.where(data['user']['LoS'] != -1)[0]
|
|
|
|
| 347 |
return label.astype(int)
|
| 348 |
|
| 349 |
def steering_vec(array, phi=0, theta=0, kd=np.pi):
|
|
|
|
|
|
|
| 350 |
idxs = DeepMIMOv3.ant_indices(array)
|
| 351 |
resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
|
| 352 |
return resp / np.linalg.norm(resp)
|
| 353 |
|
|
|
|
| 354 |
def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
|
| 355 |
labels = []
|
| 356 |
for scenario_idx in scenario_idxs:
|