Spaces:
Build error
Build error
import numpy as np
import torch
def fake_face_collator(batch):
    """Collate fake/real face examples into one batch for a vision transformer.

    Args:
        batch (list): The examples to collate. Each example is a mapping with
            a ``'pixel_values'`` entry and a ``'labels'`` entry. Only element
            ``0`` of ``'pixel_values'`` is used. It may be a NumPy array, a
            torch tensor, or a nested list of numbers.

    Returns:
        dict: ``{'pixel_values': Tensor, 'labels': Tensor}``. Each value holds
        the per-example tensors stacked along a new leading dimension, in the
        order the examples appear in ``batch``. All per-example pixel tensors
        must share one shape.

    Note:
        An empty ``batch`` is not supported; ``torch.stack`` raises on an
        empty list.
    """
    # NOTE(review): indexing [0] presumably strips a leading batch dimension
    # of size 1 added by an image processor -- confirm against the dataset
    # transform. torch.as_tensor shares memory with an ndarray (like
    # torch.from_numpy) and returns an existing tensor unchanged.
    pixel_values = torch.stack(
        [torch.as_tensor(example['pixel_values'][0]) for example in batch]
    )
    # as_tensor (not torch.tensor) avoids the copy-construct UserWarning when
    # a label is already a tensor; for plain ints/floats the result is the same.
    labels = torch.stack(
        [torch.as_tensor(example['labels']) for example in batch]
    )
    return {'pixel_values': pixel_values, 'labels': labels}