# BAPULM / data / dataset.py
# Uploaded by Moreza009 using huggingface_hub (commit 15c5ffb, verified).
import json
import torch
from torch.utils.data import Dataset
class BindingAffinityDataset(Dataset):
    """Map-style dataset of precomputed protein/molecule embeddings and affinities.

    The whole JSON file is loaded into memory once in ``__init__``. Each record
    is read through the keys ``'prot_embedding'``, ``'mol_embedding'`` and
    ``'affinity'``. Records are converted to ``torch.float32`` tensors lazily,
    in ``__getitem__``.

    Args:
        json_file: Path to the JSON file. The top-level value is indexed with
            integer positions and measured with ``len()``, so it is presumably
            a list of records (TODO confirm against the data-preparation code).
        strict: If ``False`` (the default, matching the original behavior), a
            malformed JSON file is reported on stdout and the dataset is left
            empty. If ``True``, the parse error is raised as ``ValueError``
            instead, so a bad file cannot silently produce an empty dataset.

    Raises:
        OSError: If ``json_file`` cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        ValueError: If ``strict`` is ``True`` and the file is not valid JSON.
    """

    def __init__(self, json_file, *, strict=False):
        self.data = []
        # JSON text is UTF-8 by specification (RFC 8259). Don't depend on the
        # platform's locale encoding, which differs e.g. on Windows.
        with open(json_file, 'r', encoding='utf-8') as f:
            try:
                self.data = json.load(f)
            except json.JSONDecodeError as e:
                if strict:
                    # Chain the original error so line/column info is kept.
                    raise ValueError(
                        f"Error reading JSON file {json_file!r}: {e}"
                    ) from e
                # Best-effort mode: report and continue with an empty dataset.
                print(f"Error reading JSON file: {e}")

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(protein_embedding, chemical_embedding, affinity)`` for ``idx``.

        All three values are ``torch.float32`` tensors built from the record's
        ``'prot_embedding'``, ``'mol_embedding'`` and ``'affinity'`` entries.
        """
        item = self.data[idx]
        # squeeze(0) drops the leading dimension only when it has size 1.
        # Presumably the embeddings are stored with a singleton leading axis,
        # e.g. [[...]]; TODO confirm against the embedding export.
        protein_embedding = torch.tensor(item['prot_embedding'], dtype=torch.float32).squeeze(0)
        chemical_embedding = torch.tensor(item['mol_embedding'], dtype=torch.float32).squeeze(0)
        affinity = torch.tensor(item['affinity'], dtype=torch.float32)
        return protein_embedding, chemical_embedding, affinity