Sadjad Alikhani
commited on
Update inference.py
Browse files- inference.py +2 -2
inference.py
CHANGED
|
@@ -25,7 +25,7 @@ warnings.filterwarnings('ignore')
|
|
| 25 |
|
| 26 |
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
|
| 27 |
|
| 28 |
-
dataset =
|
| 29 |
# Process data through LWM
|
| 30 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
| 31 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
@@ -38,7 +38,7 @@ def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
|
|
| 38 |
dataset = embedding_data.float()
|
| 39 |
return dataset
|
| 40 |
|
| 41 |
-
def
|
| 42 |
|
| 43 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
| 44 |
|
|
|
|
| 25 |
|
| 26 |
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
|
| 27 |
|
| 28 |
+
dataset = prepare_for_lwm(preprocessed_chs, device)
|
| 29 |
# Process data through LWM
|
| 30 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
| 31 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
|
|
| 38 |
dataset = embedding_data.float()
|
| 39 |
return dataset
|
| 40 |
|
| 41 |
+
def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
|
| 42 |
|
| 43 |
input_ids, masked_tokens, masked_pos = zip(*data)
|
| 44 |
|