# Spaces:
# Runtime error
# Runtime error
import torch
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import DataCollatorWithPadding, BatchEncoding
| def _unpack_doc_values(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| doc_examples = [] | |
| for f in features: | |
| keys = list(f.keys()) | |
| lists_per_key = len(f[keys[0]]) | |
| for idx in range(lists_per_key): | |
| doc_examples.append({k: f[k][idx] for k in keys}) | |
| return doc_examples | |
class BiencoderCollator(DataCollatorWithPadding):
    """Collate biencoder features into one padded batch.

    Incoming feature dicts carry query fields under a ``q_`` prefix and
    document fields (parallel lists, one entry per document) under a ``d_``
    prefix. Queries and documents are padded separately by the tokenizer,
    then merged back into a single ``BatchEncoding`` with the prefixes
    restored on the collated keys.
    """

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        q_prefix, d_prefix = 'q_', 'd_'

        def _strip_prefix(feature: Dict[str, Any], prefix: str) -> Dict[str, Any]:
            # keep only keys with this prefix, dropping the prefix itself
            return {key[len(prefix):]: value
                    for key, value in feature.items()
                    if key.startswith(prefix)}

        query_examples = [_strip_prefix(feature, q_prefix) for feature in features]
        doc_examples = _unpack_doc_values(
            [_strip_prefix(feature, d_prefix) for feature in features])
        # each query must come with the same (integer) number of documents
        assert len(doc_examples) % len(query_examples) == 0, \
            '{} doc and {} queries'.format(len(doc_examples), len(query_examples))

        # already truncated during tokenization
        pad_kwargs = dict(
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        q_collated = self.tokenizer.pad(query_examples, **pad_kwargs)
        d_collated = self.tokenizer.pad(doc_examples, **pad_kwargs)

        # merge into a single BatchEncoding by adding prefix
        for key in list(q_collated.keys()):
            q_collated[q_prefix + key] = q_collated.pop(key)
        for key in d_collated:
            q_collated[d_prefix + key] = d_collated[key]
        merged_batch_dict = q_collated

        # dummy placeholder for field "labels", won't use it to compute loss
        merged_batch_dict['labels'] = torch.zeros(len(query_examples), dtype=torch.long)

        if 'kd_labels' in features[0]:
            # presumably per-query soft labels for knowledge distillation — TODO confirm against caller
            merged_batch_dict['kd_labels'] = torch.stack(
                [torch.tensor(feature['kd_labels']) for feature in features],
                dim=0).float()
        return merged_batch_dict