Sadjad Alikhani committed on
Update inference.py
Browse files- inference.py +5 -29
inference.py
CHANGED
|
@@ -23,29 +23,9 @@ import numpy as np
|
|
| 23 |
import warnings
|
| 24 |
warnings.filterwarnings('ignore')
|
| 25 |
|
| 26 |
-
def
|
| 27 |
-
torch.manual_seed(seed)
|
| 28 |
-
np.random.seed(seed)
|
| 29 |
-
|
| 30 |
-
# Use this function at the start of your code
|
| 31 |
-
set_seed(42)
|
| 32 |
-
|
| 33 |
-
# Force model weights and data to float32 precision
|
| 34 |
-
def cast_model_weights_to_float32(model):
|
| 35 |
-
for param in model.parameters():
|
| 36 |
-
param.data = param.data.float() # Cast all weights to float32
|
| 37 |
-
return model
|
| 38 |
-
|
| 39 |
-
# Device configuration
|
| 40 |
-
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
| 41 |
-
if torch.cuda.is_available():
|
| 42 |
-
torch.cuda.empty_cache()
|
| 43 |
-
|
| 44 |
-
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
| 45 |
|
| 46 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
| 47 |
-
|
| 48 |
-
lwm_model = cast_model_weights_to_float32(lwm_model)
|
| 49 |
# Process data through LWM
|
| 50 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
| 51 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
@@ -56,15 +36,14 @@ def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
|
| 56 |
embedding_data = embedding_data[:, 1:]
|
| 57 |
|
| 58 |
dataset = embedding_data.float()
|
| 59 |
-
print(dataset[0][:4])
|
| 60 |
return dataset
|
| 61 |
|
| 62 |
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
|
| 63 |
|
| 64 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
| 65 |
|
| 66 |
-
input_ids_tensor = torch.tensor(input_ids, device=device).float()
|
| 67 |
-
masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
|
| 68 |
masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
|
| 69 |
|
| 70 |
dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
|
|
@@ -84,16 +63,13 @@ def evaluate(model, dataloader):
|
|
| 84 |
masked_tokens = batch[1]
|
| 85 |
masked_pos = batch[2]
|
| 86 |
|
| 87 |
-
if idx == 0:
|
| 88 |
-
print(input_ids[0])
|
| 89 |
-
|
| 90 |
logits_lm, output = model(input_ids, masked_pos)
|
| 91 |
|
| 92 |
output_batch_preproc = output
|
| 93 |
outputs.append(output_batch_preproc)
|
| 94 |
|
| 95 |
loss_lm = criterionMCM(logits_lm, masked_tokens)
|
| 96 |
-
loss = loss_lm / torch.var(masked_tokens)
|
| 97 |
running_loss += loss.item()
|
| 98 |
|
| 99 |
average_loss = running_loss / len(dataloader)
|
|
@@ -104,6 +80,6 @@ def evaluate(model, dataloader):
|
|
| 104 |
def create_raw_dataset(data, device):
|
| 105 |
"""Create a dataset for raw channel data."""
|
| 106 |
input_ids, _, _ = zip(*data)
|
| 107 |
-
input_data = torch.tensor(input_ids, device=device)
|
| 108 |
return input_data.float()
|
| 109 |
|
|
|
|
| 23 |
import warnings
|
| 24 |
warnings.filterwarnings('ignore')
|
| 25 |
|
| 26 |
+
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
|
|
|
|
|
|
| 29 |
# Process data through LWM
|
| 30 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
| 31 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
|
|
| 36 |
embedding_data = embedding_data[:, 1:]
|
| 37 |
|
| 38 |
dataset = embedding_data.float()
|
|
|
|
| 39 |
return dataset
|
| 40 |
|
| 41 |
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
|
| 42 |
|
| 43 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
| 44 |
|
| 45 |
+
input_ids_tensor = torch.tensor(input_ids, device=device).float()
|
| 46 |
+
masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
|
| 47 |
masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
|
| 48 |
|
| 49 |
dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
|
|
|
|
| 63 |
masked_tokens = batch[1]
|
| 64 |
masked_pos = batch[2]
|
| 65 |
|
|
|
|
|
|
|
|
|
|
| 66 |
logits_lm, output = model(input_ids, masked_pos)
|
| 67 |
|
| 68 |
output_batch_preproc = output
|
| 69 |
outputs.append(output_batch_preproc)
|
| 70 |
|
| 71 |
loss_lm = criterionMCM(logits_lm, masked_tokens)
|
| 72 |
+
loss = loss_lm / torch.var(masked_tokens)
|
| 73 |
running_loss += loss.item()
|
| 74 |
|
| 75 |
average_loss = running_loss / len(dataloader)
|
|
|
|
| 80 |
def create_raw_dataset(data, device):
|
| 81 |
"""Create a dataset for raw channel data."""
|
| 82 |
input_ids, _, _ = zip(*data)
|
| 83 |
+
input_data = torch.tensor(input_ids, device=device)[:, 1:]
|
| 84 |
return input_data.float()
|
| 85 |
|