File size: 1,938 Bytes
29a351f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
"""
Evaluation script: run the trained forward-synthesis and retrosynthesis models on a few sample reactions.
"""
from transformers import pipeline
from config import FORWARD_MODEL_NAME, RETRO_MODEL_NAME
_BANNER_WIDTH = 60


def _print_banner(title, *, leading_newline=True):
    """Print *title* framed by two horizontal rules of ``=`` characters.

    Args:
        title: Text printed between the two rules.
        leading_newline: When true, emit a blank line before the first rule
            to separate this section from earlier output.
    """
    rule = "=" * _BANNER_WIDTH
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _run_tests(pipe, inputs, *, input_label, output_label, max_length, num_beams):
    """Feed each input SMILES to *pipe* and print the first generated sequence.

    Args:
        pipe: A ``text2text-generation`` pipeline (as built in ``main``).
        inputs: SMILES strings, each sent to the model on its own.
        input_label: Label shown next to each input, e.g. ``"Reactants"``.
        output_label: Label shown next to each prediction, e.g. ``"Products"``.
        max_length: ``max_length`` generation argument forwarded to *pipe*.
        num_beams: ``num_beams`` generation argument forwarded to *pipe*.

    A failure on one input is printed and the loop moves on, so a single bad
    example does not abort the rest of the evaluation.
    """
    for smiles in inputs:
        print(f"\nInput ({input_label}): {smiles}")
        try:
            result = pipe(smiles, max_length=max_length, num_beams=num_beams)
            predicted = result[0]["generated_text"]
        except Exception as exc:  # best-effort demo: report and continue
            # Include the type: str() of e.g. KeyError or a bare RuntimeError
            # is empty or ambiguous on its own.
            print(f"Error: {type(exc).__name__}: {exc}")
        else:
            print(f"Predicted ({output_label}): {predicted}")


def main():
    """Load the forward and retro models and print predictions for sample inputs.

    The forward model is given reactant SMILES and its output is shown as
    products. The retro model is given a product SMILES and its output is
    shown as reactants. Everything is printed to stdout; nothing is returned.
    """
    _print_banner("ORD Reaction Model Evaluation", leading_newline=False)

    print(f"\nLoading forward model: {FORWARD_MODEL_NAME}")
    forward_pipe = pipeline("text2text-generation", model=FORWARD_MODEL_NAME)
    print(f"Loading retro model: {RETRO_MODEL_NAME}")
    retro_pipe = pipeline("text2text-generation", model=RETRO_MODEL_NAME)

    # Test examples. In SMILES, "." separates disconnected molecules, so a
    # multi-reactant input is written as "A.B".
    forward_examples = [
        "CCO.COC(=O)C",  # ethanol + methyl acetate
        "c1ccccc1.BrBr",  # benzene + bromine (Br2); a bare "Br" would be HBr
        "CC(=O)O.CCN",  # acetic acid + ethylamine
    ]
    retro_examples = [
        "CCOC(=O)C",  # ethyl acetate
        "c1ccccc1Br",  # bromobenzene
        "CC(=O)NCC",  # N-ethylacetamide
    ]

    _print_banner("FORWARD SYNTHESIS TESTS")
    _run_tests(
        forward_pipe,
        forward_examples,
        input_label="Reactants",
        output_label="Products",
        max_length=128,
        num_beams=5,
    )

    _print_banner("RETROSYNTHESIS TESTS")
    _run_tests(
        retro_pipe,
        retro_examples,
        input_label="Product",
        output_label="Reactants",
        max_length=256,
        num_beams=10,
    )

    _print_banner("Evaluation complete!")
if __name__ == "__main__":
    # Only run the evaluation when executed as a script; importing this
    # module (e.g. to reuse main) must not load the models.
    main()
|