| import json | |
| import torch | |
| from torch.utils.data import Dataset | |
| class BindingAffinityDataset(Dataset): | |
| def __init__(self, json_file): | |
| self.data = [] | |
| with open(json_file, 'r') as f: | |
| try: | |
| self.data = json.load(f) | |
| except json.JSONDecodeError as e: | |
| print(f"Error reading JSON file: {e}") | |
| def __len__(self): | |
| return len(self.data) | |
| def __getitem__(self, idx): | |
| item = self.data[idx] | |
| protein_embedding = torch.tensor(item['prot_embedding'], dtype=torch.float32).squeeze(0) | |
| chemical_embedding = torch.tensor(item['mol_embedding'], dtype=torch.float32).squeeze(0) | |
| affinity = torch.tensor(item['affinity'], dtype=torch.float32) | |
| return protein_embedding, chemical_embedding, affinity | |