|
|
--- |
|
|
license: apache-2.0 |
|
|
language: |
|
|
- en |
|
|
metrics: |
|
|
- bleu |
|
|
- rouge |
|
|
- meteor |
|
|
- exact_match |
|
|
base_model: |
|
|
- QizhiPei/biot5-plus-base |
|
|
pipeline_tag: text2text-generation |
|
|
library_name: transformers |
|
|
--- |
|
|
# Model Card for ChemAligner-T5 |
|
|
|
|
|
|
|
|
## How to Get Started with the Model |
|
|
|
|
|
Below is an example of how to load and generate outputs with this model: |
|
|
|
|
|
```python |
|
|
import torch |
|
|
import transformers |
|
|
from huggingface_hub import login |
|
|
from transformers import AutoTokenizer |
|
|
from transformers.models.t5 import T5ForConditionalGeneration |
|
|
import torch |
|
|
|
|
|
login('<your_hf_token>') |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("Neeze/ChemAligner-T5") |
|
|
model = T5ForConditionalGeneration.from_pretrained("Neeze/ChemAligner-T5").to(device) |
|
|
|
|
|
sample_caption = ( |
|
|
"The molecule is a energy storage and a fat storage, which impacts cardiovascular " |
|
|
"disease, cancer, and metabolic syndrome, and is characterized as thyroxine treatment. " |
|
|
"The molecule is a membrane stabilizer and inflammatory, and it impacts pancreatitis. " |
|
|
"The molecule is a energy source and a nutrient, impacting both obesity and atherosclerosis." |
|
|
) |
|
|
|
|
|
task_definition = ( |
|
|
"Definition: You are given a molecule description in English. " |
|
|
"Your job is to generate the corresponding molecule in SELFIES representation.\n\n" |
|
|
) |
|
|
|
|
|
task_input = ( |
|
|
f"{task_definition}" |
|
|
f"Now complete the following example -\n" |
|
|
f"Input: {sample_caption}\nOutput: " |
|
|
) |
|
|
|
|
|
inputs = tokenizer( |
|
|
task_input, |
|
|
return_tensors="pt", |
|
|
truncation=True, |
|
|
max_length=512, |
|
|
).to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_length=512, |
|
|
num_beams=1, |
|
|
do_sample=False, |
|
|
temperature=1.0, |
|
|
decoder_start_token_id=0, |
|
|
eos_token_id=1, |
|
|
pad_token_id=0 |
|
|
) |
|
|
|
|
|
outputs = [ |
|
|
s.replace("<unk>", "").replace("<pad>", "").replace("</s>", "").strip() |
|
|
for s in tokenizer.batch_decode(outputs) |
|
|
] |
|
|
|
|
|
print(*outputs) |
|
|
|
|
|
# [C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][=Branch1][C][=O][O][C][C@@H1][Branch2][Ring1][=Branch2][C][O][C][=Branch1][C][=O][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][O][C][=Branch1][C][=O][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][C][Branch1][C][C][C][C] |
|
|
``` |
|
|
|