File size: 809 Bytes
15c5ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import json
import torch
from torch.utils.data import Dataset
class BindingAffinityDataset(Dataset):
    """Dataset of (protein embedding, molecule embedding, affinity) samples.

    Records are loaded eagerly from a single JSON file. Each record is
    indexed with the keys ``'prot_embedding'``, ``'mol_embedding'`` and
    ``'affinity'`` when a sample is requested.
    """

    def __init__(self, json_file) -> None:
        """Load every record from *json_file* into memory.

        Args:
            json_file: Path of the JSON file to read. Presumably the top-level
                value is a list of record dicts -- confirm against the code
                that produces the file.

        If the file is not valid JSON, the decode error is printed and the
        dataset is left empty (best-effort loading); a missing or unreadable
        file still raises from ``open``.
        """
        self.data = []
        # JSON text is UTF-8 by specification; do not depend on the platform's
        # locale-default encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            try:
                self.data = json.load(f)
            except json.JSONDecodeError as e:
                print(f"Error reading JSON file: {e}")

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at *idx* as three ``float32`` tensors.

        Returns:
            A ``(protein_embedding, chemical_embedding, affinity)`` tuple.
        """
        item = self.data[idx]
        # squeeze(0) drops the leading dimension only when it has size 1;
        # presumably embeddings are stored as (1, dim) -- TODO confirm.
        protein_embedding = torch.tensor(
            item['prot_embedding'], dtype=torch.float32
        ).squeeze(0)
        chemical_embedding = torch.tensor(
            item['mol_embedding'], dtype=torch.float32
        ).squeeze(0)
        affinity = torch.tensor(item['affinity'], dtype=torch.float32)
        return protein_embedding, chemical_embedding, affinity
|