jonahkall's picture
Upload 51 files
4c346eb verified
import os
from typing import TYPE_CHECKING
from unittest.mock import patch
import httpx
import pytest
from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
if TYPE_CHECKING:
    # Needed only for the string annotations on the tests' ``test_client``
    # parameter; never imported at runtime.
    from fastapi.testclient import TestClient

# True only when the GITHUB_ACTIONS environment variable is exactly "true"
# (GitHub Actions sets it on its runners); used to skip CI-unfriendly tests.
IN_GITHUB_ACTIONS: bool = os.environ.get("GITHUB_ACTIONS", "") == "true"
@pytest.mark.parametrize(
    ("smiles", "purchasable"),
    [
        ("CC(=O)OC1=CC=CC=C1C(=O)O", True),
        ("CCO", True),
        ("C1=CC=C(C=C1)C(=O)O", True),
        ("OCN1C=CC=C1C(=O)O", False),
    ],
)
def test_fetch_purchasable(
    test_client: "TestClient", smiles: str, purchasable: bool
) -> None:
    """Assert ``fetch_purchasable`` reports the expected availability of ``smiles``.

    The result of ``fetch_purchasable`` is indexed by the queried SMILES
    string; that entry must compare equal to ``purchasable``.
    """
    # Any httpx.post call made while this context is active is handled by the
    # FastAPI test client fixture rather than httpx itself.
    with patch.object(httpx, "post", test_client.post):
        lookup = fetch_purchasable(smiles)
        is_purchasable = lookup[smiles]
        assert is_purchasable == purchasable
@pytest.mark.parametrize(
    ("smiles", "solubility"),
    [
        ("CC(=O)OC1=CC=CC=C1C(=O)O", -2.5),
        ("O=C(NC1CCCC1)C(C1CC1)S1C(=N)C(C2=CC=NC3=CC=CC=C23)N=C1", -5.9),
    ],
)
def test_fetch_solubility(
    test_client: "TestClient", smiles: str, solubility: float
) -> None:
    """Check ``fetch_solubility``'s prediction for ``smiles`` against a reference.

    ``httpx.post`` is routed to the FastAPI test client while the call runs.
    The returned mapping must contain a ``"solubility"`` entry matching
    ``solubility`` within an absolute tolerance of 0.1. Only ``abs`` is given
    to ``pytest.approx``, so no relative tolerance is applied.
    """
    with patch.object(httpx, "post", test_client.post):
        result = fetch_solubility(smiles)
        assert "solubility" in result, f"No 'solubility' key in {result=}."
        # pytest convention: actual == approx(expected). Wrapping the reference
        # value (not the prediction) makes failure reports label it as expected.
        assert result["solubility"] == pytest.approx(solubility, abs=0.1)
@pytest.mark.skipif(
    IN_GITHUB_ACTIONS, reason="Requires loading MolTrans model, too large for CI"
)
@pytest.mark.parametrize(
    ("precursor", "product", "correct"),
    [
        # Well-formed inputs of the form "reactants>agents>" (empty product
        # field), each with the product the service is expected to predict.
        pytest.param("CC=O.O=C1CCC1Cl>[Mg].CCOCC>", "CC(O)C1(O)CCC1Cl", True),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
            "CCCN(C)C1CC(C)(C)NC1=O",
            True,
        ),
        pytest.param("CCCC=O.O=C1CC=C(Br)S1>[Mg].CCOCC>", "CCCC(O)C1=CCC(=O)S1", True),
        pytest.param("CCCC=O.COC(=O)C1CC1Br>[Mg].CCOCC>", "CCCC(O)C1CC1C(=O)OC", True),
        pytest.param(
            "CCCC=O.NC1CCCNC1=O>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
            "O=C1NCCCC1N1CNCCCC1=O",
            True,
        ),
        # Malformed inputs: each must be rejected with a syntax error.
        pytest.param("CC=O.O=C1CCC1Cl.[Mg].CCOCC", None, False, id="missing_arrow"),
        pytest.param(
            "CC=O.O=C1CCC1Cl > [Mg].CCOCC", None, False, id="space_in_reaction"
        ),
        pytest.param("not a > reaction", None, False, id="invalid_reaction"),
        # Same reaction as a valid case above, but without the final ">".
        pytest.param(
            "CCCC=O.COC(=O)C1CC1Br>[Mg].CCOCC",
            None,
            False,
            id="missing_trailing_arrow",
        ),
    ],
)
def test_fetch_forward_rxn(
    test_client: "TestClient", precursor: str, product: str | None, correct: bool
) -> None:
    """Exercise ``fetch_forward_rxn`` on valid and malformed reaction strings.

    For ``correct`` cases the result must carry exactly ``product`` and no
    ``"error"`` key. For malformed cases it must carry a non-empty ``"error"``
    mentioning a syntax error, and no ``"product"`` key. The test is skipped on
    GitHub Actions (see the skip reason).
    """
    with patch.object(httpx, "post", test_client.post):
        result = fetch_forward_rxn(precursor)
        if correct:
            assert (
                result.get("product") == product
            ), f"Failed to get expected {product=} in {result=}."
            assert "error" not in result, f"Unexpected error for {precursor=}: {result=}."
        else:
            assert result.get(
                "error"
            ), f"Expected an error given {precursor=} and {correct=}"
            assert "syntax error" in result["error"].lower(), (
                f"Expected a syntax error for {precursor=}, got {result['error']!r}."
            )
            assert "product" not in result, (
                f"Unexpected product for malformed {precursor=}: {result=}."
            )