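This listing is the Space's ranking code: a two-stage pairwise re-ranker for Vietnamese question answering built on the nguyenvulebinh/vi-mrc-base checkpoint. Stage 1 scores raw passages against a question; stage 2 re-scores a proposed answer together with each passage's title.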
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoConfig, AutoTokenizer
import pandas as pd

AUTH_TOKEN = "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"

# Shared tokenizer; its pad token id is reused by collate_fn below.
tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
                                          use_auth_token=AUTH_TOKEN)
pad_token_id = tokenizer.pad_token_id
class PairwiseModel(nn.Module):
    """Cross-encoder that scores a concatenated text pair with a single logit."""

    def __init__(self, model_name, max_length=384, batch_size=16, device="cpu"):
        super(PairwiseModel, self).__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = device
        self.model = AutoModel.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        self.model.to(self.device)
        self.model.eval()
        self.config = AutoConfig.from_pretrained(model_name, use_auth_token=AUTH_TOKEN)
        # Linear head on top of the first-token ([CLS]/<s>) representation.
        self.fc = nn.Linear(768, 1).to(self.device)

    def forward(self, ids, masks):
        out = self.model(input_ids=ids,
                         attention_mask=masks,
                         output_hidden_states=False).last_hidden_state
        out = out[:, 0]         # first-token embedding of each sequence
        outputs = self.fc(out)  # one relevance logit per pair
        return outputs
    def stage1_ranking(self, question, texts):
        """Score each raw passage in `texts` against `question`; returns one score per passage."""
        tmp = pd.DataFrame()
        tmp["text"] = [" ".join(x.split()) for x in texts]  # collapse whitespace
        tmp["question"] = question
        valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                  num_workers=0, shuffle=False, pin_memory=True)
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.cat(preds)
        return preds.cpu().numpy()
    def stage2_ranking(self, question, answer, titles, texts):
        """Re-score answer candidates using the question, a proposed answer, and each passage's title."""
        tmp = pd.DataFrame()
        tmp["candidate"] = texts
        tmp["question"] = question
        tmp["answer"] = answer
        tmp["title"] = titles
        valid_dataset = SiameseDatasetStage2(tmp, tokenizer, self.max_length, is_test=True)
        valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                                  num_workers=0, shuffle=False, pin_memory=True)
        preds = []
        with torch.no_grad():
            for step, data in enumerate(valid_loader):
                ids = data["ids"].to(self.device)
                masks = data["masks"].to(self.device)
                preds.append(torch.sigmoid(self(ids, masks)).view(-1))
        preds = torch.cat(preds)
        return preds.cpu().numpy()
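# A minimal smoke test for the forward pass -- an illustrative sketch, with a
# hypothetical helper name and made-up sample strings. It assumes the
# 'nguyenvulebinh/vi-mrc-base' checkpoint is reachable with the token above.
def _demo_forward():
    model = PairwiseModel("nguyenvulebinh/vi-mrc-base")
    # Encode one (question, passage) pair; the tokenizer inserts separators.
    enc = tokenizer("một câu hỏi", "một đoạn văn", return_tensors="pt")
    with torch.no_grad():
        logit = model(enc["input_ids"], enc["attention_mask"])
    print(logit.shape)  # torch.Size([1, 1]): one relevance logit per pair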
class SiameseDatasetStage1(Dataset):
    """Stage-1 pairs: question vs. raw passage text."""

    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        self.content1 = tokenizer.batch_encode_plus(
            list(df.question.values), max_length=max_length, truncation=True)["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            list(df.text.values), max_length=max_length, truncation=True)["input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # Drop the second sequence's leading BOS token so the two halves
            # concatenate into one well-formed input in collate_fn.
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
        }
class SiameseDatasetStage2(Dataset):
    """Stage-2 pairs: 'question [SEP] answer' vs. 'title [SEP] candidate'."""

    def __init__(self, df, tokenizer, max_length, is_test=False):
        self.df = df
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_test = is_test
        self.df["content1"] = self.df.apply(
            lambda row: row.question + f" {tokenizer.sep_token} " + row.answer, axis=1)
        self.df["content2"] = self.df.apply(
            lambda row: row.title + f" {tokenizer.sep_token} " + row.candidate, axis=1)
        self.content1 = tokenizer.batch_encode_plus(
            list(df.content1.values), max_length=max_length, truncation=True)["input_ids"]
        self.content2 = tokenizer.batch_encode_plus(
            list(df.content2.values), max_length=max_length, truncation=True)["input_ids"]
        if not self.is_test:
            self.targets = self.df.label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return {
            'ids1': torch.tensor(self.content1[index], dtype=torch.long),
            # As in stage 1, drop the leading BOS token of the second half.
            'ids2': torch.tensor(self.content2[index][1:], dtype=torch.long),
            'target': torch.tensor(0) if self.is_test else torch.tensor(self.targets[index], dtype=torch.float)
        }
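# Illustrative sketch (hypothetical helper; the toy row is made up): inspect
# what one stage-2 item looks like after tokenization. ids1 encodes
# "question [SEP] answer"; ids2 encodes "title [SEP] candidate" minus its
# leading BOS token.
def _demo_dataset_item():
    toy = pd.DataFrame({"question": ["Ai là tác giả của Truyện Kiều?"],
                        "answer": ["Nguyễn Du"],
                        "title": ["Truyện Kiều"],
                        "candidate": ["Nguyễn Du là tác giả của Truyện Kiều."]})
    ds = SiameseDatasetStage2(toy, tokenizer, max_length=384, is_test=True)
    item = ds[0]
    # Decoding the concatenated halves shows the exact string the model scores.
    print(tokenizer.decode(torch.cat([item["ids1"], item["ids2"]])))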
def collate_fn(batch):
    """Concatenate each item's two halves, right-pad to the batch's longest sequence, and build attention masks."""
    ids = [torch.cat([x["ids1"], x["ids2"]]) for x in batch]
    targets = [x["target"] for x in batch]
    max_len = np.max([len(x) for x in ids])
    masks = []
    for i in range(len(ids)):
        if len(ids[i]) < max_len:
            ids[i] = torch.cat((ids[i],
                                torch.tensor([pad_token_id] * (max_len - len(ids[i])), dtype=torch.long)))
        masks.append(ids[i] != pad_token_id)  # True on real tokens, False on padding
    outputs = {
        "ids": torch.vstack(ids),
        "masks": torch.vstack(masks),
        "target": torch.vstack(targets).view(-1)
    }
    return outputs
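# Illustrative end-to-end usage (hypothetical helper; question, answer, and
# passages are made up): stage 1 ranks raw passages against the question,
# then stage 2 re-scores a proposed answer together with each passage's title.
# Assumes the checkpoint downloads successfully with the token above.
def _demo_two_stage_ranking():
    model = PairwiseModel("nguyenvulebinh/vi-mrc-base")
    question = "Ai là tác giả của Truyện Kiều?"
    texts = ["Nguyễn Du là tác giả của Truyện Kiều.",
             "Hà Nội là thủ đô của Việt Nam."]
    stage1_scores = model.stage1_ranking(question, texts)  # np.ndarray of sigmoid scores in [0, 1]
    stage2_scores = model.stage2_ranking(question, "Nguyễn Du",
                                         titles=["Truyện Kiều", "Hà Nội"],
                                         texts=texts)
    print(stage1_scores, stage2_scores)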