Upload tutorial.py
Browse files- tutorial.py +33 -3
tutorial.py
CHANGED
|
@@ -74,11 +74,13 @@ scenario_names = np.array([
|
|
| 74 |
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
|
| 75 |
selected_scenario_names = scenario_names[scenario_idxs]
|
| 76 |
|
|
|
|
|
|
|
| 77 |
preprocessed_chs = tokenizer(
|
| 78 |
selected_scenario_names=selected_scenario_names,
|
| 79 |
manual_data=None,
|
| 80 |
gen_raw=True,
|
| 81 |
-
snr_db=
|
| 82 |
)
|
| 83 |
|
| 84 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -87,7 +89,7 @@ model = lwm.from_pretrained(device=device)
|
|
| 87 |
#%%
|
| 88 |
from inference import lwm_inference, create_raw_dataset
|
| 89 |
input_types = ['cls_emb', 'channel_emb', 'raw']
|
| 90 |
-
selected_input_type = input_types[
|
| 91 |
|
| 92 |
if selected_input_type in ['cls_emb', 'channel_emb']:
|
| 93 |
dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
|
|
@@ -139,7 +141,7 @@ for task in tasks:
|
|
| 139 |
#%% TRAINING
|
| 140 |
#%% TRAINING PARAMETERS
|
| 141 |
task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
|
| 142 |
-
n_trials =
|
| 143 |
num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
|
| 144 |
input_types = ['raw', 'cls_emb'] # Types of input data
|
| 145 |
split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
|
|
@@ -174,6 +176,20 @@ for input_type_idx, input_type in enumerate(input_types):
|
|
| 174 |
print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
|
| 175 |
|
| 176 |
torch.manual_seed(trial) # Set seed for reproducibility
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
train_loader, test_loader = get_data_loaders(
|
| 178 |
dataset,
|
| 179 |
labels,
|
|
@@ -240,6 +256,20 @@ for input_type_idx, input_type in enumerate(input_types):
|
|
| 240 |
for trial in range(n_trials):
|
| 241 |
|
| 242 |
torch.manual_seed(trial) # Set seed for reproducibility
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
train_loader, test_loader = get_data_loaders(
|
| 244 |
dataset,
|
| 245 |
labels,
|
|
|
|
| 74 |
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])[3]
|
| 75 |
selected_scenario_names = scenario_names[scenario_idxs]
|
| 76 |
|
| 77 |
+
snr_db = None
|
| 78 |
+
|
| 79 |
preprocessed_chs = tokenizer(
|
| 80 |
selected_scenario_names=selected_scenario_names,
|
| 81 |
manual_data=None,
|
| 82 |
gen_raw=True,
|
| 83 |
+
snr_db=snr_db
|
| 84 |
)
|
| 85 |
|
| 86 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 89 |
#%%
|
| 90 |
from inference import lwm_inference, create_raw_dataset
|
| 91 |
input_types = ['cls_emb', 'channel_emb', 'raw']
|
| 92 |
+
selected_input_type = input_types[1]
|
| 93 |
|
| 94 |
if selected_input_type in ['cls_emb', 'channel_emb']:
|
| 95 |
dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
|
|
|
|
| 141 |
#%% TRAINING
|
| 142 |
#%% TRAINING PARAMETERS
|
| 143 |
task = ['LoS/NLoS Classification', 'Beam Prediction'][0] # Select the task
|
| 144 |
+
n_trials = 1 # Number of trials for each configuration
|
| 145 |
num_classes = 2 if task == 'LoS/NLoS Classification' else n_beams # Set number of classes based on the task
|
| 146 |
input_types = ['raw', 'cls_emb'] # Types of input data
|
| 147 |
split_ratios = np.array([.005, .0075, .01, .015, .02, .03,
|
|
|
|
| 176 |
print(f"\ninput type: {input_type}, \nnumber of training samples: {int(split_ratio*len(dataset))}, \ntrial: {trial}\n")
|
| 177 |
|
| 178 |
torch.manual_seed(trial) # Set seed for reproducibility
|
| 179 |
+
|
| 180 |
+
if snr_db is not None:
|
| 181 |
+
preprocessed_chs = tokenizer(
|
| 182 |
+
selected_scenario_names=selected_scenario_names,
|
| 183 |
+
manual_data=None,
|
| 184 |
+
gen_raw=True,
|
| 185 |
+
snr_db=snr_db
|
| 186 |
+
)
|
| 187 |
+
if input_type in ['cls_emb', 'channel_emb']:
|
| 188 |
+
dataset = lwm_inference(preprocessed_chs, input_type, model, device)
|
| 189 |
+
else:
|
| 190 |
+
dataset = create_raw_dataset(preprocessed_chs, device)
|
| 191 |
+
dataset = dataset.view(dataset.size(0), -1)
|
| 192 |
+
|
| 193 |
train_loader, test_loader = get_data_loaders(
|
| 194 |
dataset,
|
| 195 |
labels,
|
|
|
|
| 256 |
for trial in range(n_trials):
|
| 257 |
|
| 258 |
torch.manual_seed(trial) # Set seed for reproducibility
|
| 259 |
+
|
| 260 |
+
if snr_db is not None:
|
| 261 |
+
preprocessed_chs = tokenizer(
|
| 262 |
+
selected_scenario_names=selected_scenario_names,
|
| 263 |
+
manual_data=None,
|
| 264 |
+
gen_raw=True,
|
| 265 |
+
snr_db=snr_db
|
| 266 |
+
)
|
| 267 |
+
if input_type in ['cls_emb', 'channel_emb']:
|
| 268 |
+
dataset = lwm_inference(preprocessed_chs, input_type, model, device)
|
| 269 |
+
else:
|
| 270 |
+
dataset = create_raw_dataset(preprocessed_chs, device)
|
| 271 |
+
dataset = dataset.view(dataset.size(0), -1)
|
| 272 |
+
|
| 273 |
train_loader, test_loader = get_data_loaders(
|
| 274 |
dataset,
|
| 275 |
labels,
|