Spaces:
Runtime error
Runtime error
Commit ·
dd58cce
1
Parent(s): 31f75a0
Update pairwise_model.py
Browse files- pairwise_model.py +8 -3
pairwise_model.py
CHANGED
|
@@ -8,7 +8,7 @@ from optimum.intel import OVModelForQuestionAnswering
|
|
| 8 |
import openvino.inference_engine as ie
|
| 9 |
import os
|
| 10 |
import gradio as gr
|
| 11 |
-
|
| 12 |
AUTH_TOKEN = "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"
|
| 13 |
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
|
@@ -35,7 +35,12 @@ class PairwiseModel_modify(nn.Module):
|
|
| 35 |
|
| 36 |
def forward(self, ids, masks):
|
| 37 |
# Export the model to ONNX format
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# Specify the input shapes (batch_size, max_sequence_length)
|
| 40 |
input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
|
| 41 |
|
|
@@ -57,7 +62,7 @@ class PairwiseModel_modify(nn.Module):
|
|
| 57 |
tmp["question"] = question
|
| 58 |
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
| 59 |
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
| 60 |
-
num_workers=
|
| 61 |
preds = []
|
| 62 |
with torch.no_grad():
|
| 63 |
bar = enumerate(valid_loader)
|
|
|
|
| 8 |
import openvino.inference_engine as ie
|
| 9 |
import os
|
| 10 |
import gradio as gr
|
| 11 |
+
from multiprocessing import cpu_count
|
| 12 |
AUTH_TOKEN = "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"
|
| 13 |
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
|
|
|
|
| 35 |
|
| 36 |
def forward(self, ids, masks):
|
| 37 |
# Export the model to ONNX format
|
| 38 |
+
ids_np = ids.cpu().numpy().astype(np.int64)
|
| 39 |
+
masks_np = masks.cpu().numpy().astype(np.int64)
|
| 40 |
+
ids_device = torch.from_numpy(ids_np).to(self.device)
|
| 41 |
+
masks_device = torch.from_numpy(masks_np).to(self.device)
|
| 42 |
+
|
| 43 |
+
input_feed = {"input_ids": ids_device, "attention_mask": masks_device}
|
| 44 |
# Specify the input shapes (batch_size, max_sequence_length)
|
| 45 |
input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
|
| 46 |
|
|
|
|
| 62 |
tmp["question"] = question
|
| 63 |
valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
|
| 64 |
valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
|
| 65 |
+
num_workers=cpu_count(), shuffle=False, pin_memory=True)
|
| 66 |
preds = []
|
| 67 |
with torch.no_grad():
|
| 68 |
bar = enumerate(valid_loader)
|