Update prediction.py
Browse files- prediction.py +16 -8
prediction.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
from dataloader import CellLoader
|
| 2 |
|
| 3 |
-
def
|
| 4 |
sequence_input,
|
| 5 |
nucleus_image,
|
|
|
|
| 6 |
model,
|
| 7 |
device
|
| 8 |
):
|
|
@@ -15,6 +16,7 @@ def run_image_prediction(
|
|
| 15 |
:param model_ckpt_path: Path to model checkpoint
|
| 16 |
:param model_config_path: Path to model config
|
| 17 |
"""
|
|
|
|
| 18 |
# Instantiate dataset object
|
| 19 |
dataset = CellLoader(
|
| 20 |
sequence_mode="embedding",
|
|
@@ -28,20 +30,26 @@ def run_image_prediction(
|
|
| 28 |
threshold="median",
|
| 29 |
)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# Convert SEQUENCE to sequence using dataset.tokenize_sequence()
|
| 32 |
sequence = dataset.tokenize_sequence(sequence_input)
|
| 33 |
|
| 34 |
# Sample from model using provided sequence and nucleus image
|
| 35 |
-
_,
|
| 36 |
text=sequence.to(device),
|
| 37 |
condition=nucleus_image.to(device),
|
| 38 |
-
|
|
|
|
| 39 |
temperature=1,
|
| 40 |
progress=False,
|
| 41 |
)
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
predicted_threshold = predicted_threshold.cpu()[0, 0]
|
| 45 |
-
predicted_heatmap = predicted_heatmap.cpu()[0, 0]
|
| 46 |
-
|
| 47 |
-
return predicted_threshold, predicted_heatmap
|
|
|
|
| 1 |
from dataloader import CellLoader
|
| 2 |
|
| 3 |
+
def run_sequence_prediction(
|
| 4 |
sequence_input,
|
| 5 |
nucleus_image,
|
| 6 |
+
protein_image,
|
| 7 |
model,
|
| 8 |
device
|
| 9 |
):
|
|
|
|
| 16 |
:param model_ckpt_path: Path to model checkpoint
|
| 17 |
:param model_config_path: Path to model config
|
| 18 |
"""
|
| 19 |
+
|
| 20 |
# Instantiate dataset object
|
| 21 |
dataset = CellLoader(
|
| 22 |
sequence_mode="embedding",
|
|
|
|
| 30 |
threshold="median",
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# Check if sequence is provided and valid
|
| 34 |
+
if len(sequence_input) == 0:
|
| 35 |
+
raise ValueError("Sequence must be provided.")
|
| 36 |
+
|
| 37 |
+
if "<mask>" not in sequence_input:
|
| 38 |
+
print("Warning: Sequence does not contain any masked positions to predict.")
|
| 39 |
+
|
| 40 |
# Tokenize the input sequence string using dataset.tokenize_sequence()
|
| 41 |
sequence = dataset.tokenize_sequence(sequence_input)
|
| 42 |
|
| 43 |
# Sample the predicted sequence from the model using the provided sequence, nucleus image, and protein image
|
| 44 |
+
_, predicted_sequence, _ = model.celle.sample_text(
|
| 45 |
text=sequence.to(device),
|
| 46 |
condition=nucleus_image.to(device),
|
| 47 |
+
image=protein_image.to(device),
|
| 48 |
+
force_aas=True,
|
| 49 |
temperature=1,
|
| 50 |
progress=False,
|
| 51 |
)
|
| 52 |
+
|
| 53 |
+
os.chdir(base_dir)
|
| 54 |
|
| 55 |
+
return predicted_sequence
|
|
|
|
|
|
|
|
|
|
|
|