Add download-verified inference example
Browse files
README.md
CHANGED
|
@@ -71,3 +71,31 @@ If you use this model, please cite:
|
|
| 71 |
url={{https://huggingface.co/kojima-lab/molcrawl-compounds-chemberta2-large}}
|
| 72 |
}
|
| 73 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
url={{https://huggingface.co/kojima-lab/molcrawl-compounds-chemberta2-large}}
|
| 72 |
}
|
| 73 |
```
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## Example Output
|
| 77 |
+
|
| 78 |
+
End-to-end inference test: the model was downloaded from this repo and run on CPU.
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
import torch
|
| 82 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 83 |
+
|
| 84 |
+
REPO_ID = "kojima-lab/molcrawl-compounds-chemberta2-large"
|
| 85 |
+
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
|
| 86 |
+
model = AutoModelForMaskedLM.from_pretrained(REPO_ID)
|
| 87 |
+
model.eval()
|
| 88 |
+
|
| 89 |
+
# SMILES with one masked position
|
| 90 |
+
prompt = "CC(=O)Oc1ccccc1[MASK](=O)O"
|
| 91 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
| 92 |
+
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
|
| 93 |
+
|
| 94 |
+
with torch.no_grad():
|
| 95 |
+
    outputs = model(**inputs)
|
| 96 |
+
|
| 97 |
+
predicted_id = outputs.logits[0, mask_index].argmax(dim=-1)
|
| 98 |
+
predicted_token = tokenizer.convert_ids_to_tokens(predicted_id.tolist())[0]
|
| 99 |
+
print(f"Predicted token at mask: {predicted_token}")
|
| 100 |
+
# => Predicted token at mask: C
|
| 101 |
+
```
|