fix target context to Lenselink due to trained checkpoint
Browse files- src/dataset.py +10 -5
src/dataset.py
CHANGED
|
@@ -39,16 +39,16 @@ class DrugRetrieval(Dataset):
|
|
| 39 |
self.remove_batch = True
|
| 40 |
|
| 41 |
assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), 'Drug embeddings not available.'
|
| 42 |
-
assert os.path.exists(
|
| 43 |
|
| 44 |
# Drugs
|
| 45 |
-
emb_dict = self.
|
| 46 |
self.drug_ids = list(emb_dict.keys())
|
| 47 |
self.drug_embeddings = list(emb_dict.values())
|
| 48 |
|
| 49 |
# Context
|
| 50 |
self.target_scaler = StandardScaler()
|
| 51 |
-
context = self.
|
| 52 |
self.context = self.standardize(embeddings=context)
|
| 53 |
|
| 54 |
# Query target
|
|
@@ -71,8 +71,13 @@ class DrugRetrieval(Dataset):
|
|
| 71 |
def __len__(self):
|
| 72 |
return len(self.drug_ids)
|
| 73 |
|
| 74 |
-
def
|
| 75 |
-
with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
embeddings = pickle.load(handle)
|
| 77 |
return embeddings
|
| 78 |
|
|
|
|
| 39 |
self.remove_batch = True
|
| 40 |
|
| 41 |
assert os.path.exists(os.path.join(self.data_path, f'processed/{drug_encoder}_encoding.pickle')), 'Drug embeddings not available.'
|
| 42 |
+
assert os.path.exists(f'data/Lenselink/processed/{target_encoder}_encoding_train.pickle')), 'Context target embeddings not available.'
|
| 43 |
|
| 44 |
# Drugs
|
| 45 |
+
emb_dict = self.get_drug_embeddings(encoder_name=drug_encoder)
|
| 46 |
self.drug_ids = list(emb_dict.keys())
|
| 47 |
self.drug_embeddings = list(emb_dict.values())
|
| 48 |
|
| 49 |
# Context
|
| 50 |
self.target_scaler = StandardScaler()
|
| 51 |
+
context = self.get_target_embeddings(encoder_name=target_encoder)
|
| 52 |
self.context = self.standardize(embeddings=context)
|
| 53 |
|
| 54 |
# Query target
|
|
|
|
| 71 |
def __len__(self):
|
| 72 |
return len(self.drug_ids)
|
| 73 |
|
| 74 |
+
def get_drug_embeddings(self, encoder_name):
|
| 75 |
+
with open(os.path.join(self.data_path, f'processed/{encoder_name}_encoding.pickle'), 'rb') as handle:
|
| 76 |
+
embeddings = pickle.load(handle)
|
| 77 |
+
return embeddings
|
| 78 |
+
|
| 79 |
+
def get_target_embeddings(self, encoder_name):
|
| 80 |
+
with open(f'data/Lenselink/processed/{encoder_name}_encoding_train.pickle'), 'rb') as handle:
|
| 81 |
embeddings = pickle.load(handle)
|
| 82 |
return embeddings
|
| 83 |
|