"""Example: regenerate a SMILES string from a molecule's Morgan fingerprint.

The script loads the ``lamthuy/MorganGen`` seq2seq checkpoint and its
tokenizer, converts an input SMILES into a Morgan fingerprint, renders the
fingerprint in the indices text format, and prints the text the model
generates from that input.
"""

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import MorganFingerprint, morgan_fingerprint_to_text

# Load the checkpoint and the tokenizer.
checkpoint_path = "lamthuy/MorganGen"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

# Given a SMILES (here: aspirin), get its fingerprint.
smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
m = MorganFingerprint()
mf = m.smiles_to_morgan(smiles)

# Convert the fingerprint to the indices text format.
s = morgan_fingerprint_to_text(mf)

# Encode the text as a batch of one sequence of token ids (PyTorch tensor).
input_ids = tokenizer.encode(s, return_tensors="pt")

# Generate the output sequence with beam search (5 beams), capped at a
# length of 64 tokens.
output_ids = model.generate(input_ids, max_length=64, num_beams=5)

# Decode the first returned sequence, dropping special tokens.
output_smiles = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_smiles)