|
|
""" |
|
|
Evaluation script: run the trained forward-synthesis and retrosynthesis models on a few sample reactions and print their predictions.
|
|
""" |
|
|
from transformers import pipeline |
|
|
from config import FORWARD_MODEL_NAME, RETRO_MODEL_NAME |
|
|
|
|
|
|
|
|
def _print_banner(title, *, leading_newline=True):
    """Print *title* framed above and below by a 60-character rule of ``=``.

    When *leading_newline* is true, an empty line is printed before the
    first rule so the banner is visually separated from earlier output.
    """
    rule = "=" * 60
    print(("\n" + rule) if leading_newline else rule)
    print(title)
    print(rule)


def _run_examples(pipe, examples, input_label, output_label, **generate_kwargs):
    """Run *pipe* on each example and print its top prediction.

    Args:
        pipe: Text-to-text pipeline, called as ``pipe(text, **generate_kwargs)``;
            its return value is indexed as ``result[0]["generated_text"]``.
        examples: Input strings fed to *pipe* one at a time.
        input_label: Label printed with each input, e.g. ``"Reactants"``.
        output_label: Label printed with each prediction, e.g. ``"Products"``.
        **generate_kwargs: Extra keyword arguments forwarded to *pipe*
            (e.g. ``max_length``, ``num_beams``).

    A failure on one example is printed and does not stop the remaining
    examples from being evaluated.
    """
    for text in examples:
        print(f"\nInput ({input_label}): {text}")
        try:
            result = pipe(text, **generate_kwargs)
            predicted = result[0]["generated_text"]
        except Exception as e:  # best-effort smoke test: report and continue
            print(f"Error: {e}")
        else:
            print(f"Predicted ({output_label}): {predicted}")


def main():
    """Load the forward and retro models and print predictions for sample inputs."""
    _print_banner("ORD Reaction Model Evaluation", leading_newline=False)

    print(f"\nLoading forward model: {FORWARD_MODEL_NAME}")
    forward_pipe = pipeline("text2text-generation", model=FORWARD_MODEL_NAME)

    print(f"Loading retro model: {RETRO_MODEL_NAME}")
    retro_pipe = pipeline("text2text-generation", model=RETRO_MODEL_NAME)

    # Reactant sets written as SMILES, with '.' separating molecules.
    forward_examples = [
        "CCO.COC(=O)C",
        "c1ccccc1.Br",
        "CC(=O)O.CCN",
    ]

    # Single products for retrosynthesis. NOTE(review): these appear to be the
    # intended products of the forward examples at the same index; confirm.
    retro_examples = [
        "CCOC(=O)C",
        "c1ccccc1Br",
        "CC(=O)NCC",
    ]

    _print_banner("FORWARD SYNTHESIS TESTS")
    _run_examples(
        forward_pipe,
        forward_examples,
        "Reactants",
        "Products",
        max_length=128,
        num_beams=5,
    )

    _print_banner("RETROSYNTHESIS TESTS")
    _run_examples(
        retro_pipe,
        retro_examples,
        "Product",
        "Reactants",
        max_length=256,
        num_beams=10,
    )

    _print_banner("Evaluation complete!")


if __name__ == "__main__":
    main()
|
|
|