# AI-based D-amino acid substitution for optimizing antimicrobial peptides to treat multidrug-resistant bacterial infection

This repository contains the code for the paper "AI-based D-amino acid substitution for optimizing antimicrobial peptides to treat multidrug-resistant bacterial infection".

## Requirements
```
mamba_ssm==2.2.4
numpy==1.26.3
pandas==2.1.4
rdkit==2024.3.5
scikit_learn==1.4.1.post1
scipy==1.13.0
torch==2.2.0
torchmetrics==1.3.1
torchvision==0.17.0
```
You can install them with `pip install -r requirements.txt`.

`mamba_ssm` is optional because it is not used in our final method. If you don't want to install it, comment out `mamba_ssm==2.2.4` in `requirements.txt` and `from mamba_ssm import Mamba` in `network.py`, and do not use `--q-encoder mamba`.
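
Instead of commenting the import out by hand, `network.py` could treat `mamba_ssm` as an optional dependency. The sketch below is not the repository's code and assumes nothing about it beyond the `from mamba_ssm import Mamba` line quoted above; `check_encoder_available` is a hypothetical helper that fails with a clear message only when `--q-encoder mamba` is actually requested.

```python
# Hypothetical sketch (not the repository's network.py): treat mamba_ssm as an
# optional dependency instead of commenting out the import by hand.
try:
    from mamba_ssm import Mamba  # only needed for --q-encoder mamba
except ImportError:
    Mamba = None  # mamba_ssm is not installed (or could not be imported)


def check_encoder_available(q_encoder: str) -> None:
    """Raise a clear error if the mamba encoder is requested without mamba_ssm.

    `check_encoder_available` is a hypothetical helper name.
    """
    if q_encoder == "mamba" and Mamba is None:
        raise ImportError(
            "--q-encoder mamba requires mamba_ssm; "
            "install it with `pip install mamba_ssm==2.2.4`"
        )
```

With this pattern, `check_encoder_available("cnn")` always succeeds, so the other encoders keep working when `mamba_ssm` is absent.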
## Training

There are two training scripts: `main.py` and `main_simple.py`.

`main.py`: Trains models for both classification and regression tasks. Preferred for the regression task.

`main_simple.py`: Trains models for the classification task ONLY. Preferred for the classification task. `simple` refers to a simplified dataset that directly loads pre-processed data.

Example:
```bash
# Options used below (see main.py / main_simple.py for the full list):
#   --q-encoder   encoder: cnn, lstm, gru, mamba or mha
#   --channels    encoder channels
#   --side-enc    side-sequence encoder; only lstm is implemented, use only with the cnn encoder
#   --fusion      fusion method: att, mlp or diff
#   --task        task: cls or reg
#   --loss        loss: ce or mse (some other losses can be found in the code)
#   --batch-size  batch size
#   --epochs      number of epochs
#   --gpu         GPU index to use, -1 for CPU
# CNN-only options:
#   --pcs         enable protease cleavage site dyeing for the input pictures
#   --resize      resize the input pictures; 1 or 2 numbers, e.g. 768 or 768 512
# main_simple.py-only options:
#   --llm-data    use LLM-augmented training data
python main_simple.py \
  --q-encoder cnn \
  --channels 16 \
  --side-enc lstm \
  --fusion att \
  --task cls \
  --loss ce \
  --batch-size 32 \
  --epochs 35 \
  --gpu 0 \
  --pcs \
  --resize 768 \
  --llm-data
```
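
To make the option list concrete, here is a minimal `argparse` sketch of a parser that would accept the documented flags. It is an illustration under assumptions, not the parser in `main.py` or `main_simple.py`: it sets no defaults, leaves `--loss` unrestricted because other losses exist in the code, and the `build_parser`/`parse` names are hypothetical. It mainly shows how a 1-or-2-number `--resize` can be handled with `nargs="+"` plus a length check.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the options listed in the training example.

    NOT the repository's parser: defaults and extra options live in
    main.py / main_simple.py. Only choices documented in this README are used.
    """
    p = argparse.ArgumentParser(description="Sketch of the training CLI")
    p.add_argument("--q-encoder", choices=["cnn", "lstm", "gru", "mamba", "mha"])
    p.add_argument("--channels", type=int)
    p.add_argument("--side-enc", choices=["lstm"])  # only lstm is implemented
    p.add_argument("--fusion", choices=["att", "mlp", "diff"])
    p.add_argument("--task", choices=["cls", "reg"])
    p.add_argument("--loss")  # ce or mse; other losses are defined in the code
    p.add_argument("--batch-size", type=int)
    p.add_argument("--epochs", type=int)
    p.add_argument("--gpu", type=int, help="GPU index, -1 for CPU")
    p.add_argument("--pcs", action="store_true")  # CNN only
    p.add_argument("--resize", type=int, nargs="+")  # CNN only: 1 or 2 numbers
    p.add_argument("--llm-data", action="store_true")  # main_simple.py only
    return p


def parse(argv: List[str]) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # nargs="+" accepts any count >= 1, so enforce the documented 1-or-2 rule.
    if args.resize is not None and len(args.resize) not in (1, 2):
        parser.error("--resize takes 1 or 2 numbers, e.g. 768 or 768 512")
    return args
```

For example, `parse(["--resize", "768", "512", "--gpu", "-1"])` yields `resize == [768, 512]` and `gpu == -1`; `argparse` accepts `-1` as a value here because no option name looks like a negative number.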
The corresponding model weight checkpoints are saved in a subdirectory of `run-cls` or `run-reg`, e.g. `/run-cls/cnn-att-16-lstm-pcs-simple-llm-768-oneway-ce-32-0.001-35/`.

For more arguments, please refer to the code of `main.py` or `main_simple.py`.
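
Because every training configuration gets its own subdirectory, it can help to list the available runs before picking one for inference. The helper below is a hypothetical sketch that assumes only the `run-cls`/`run-reg` layout described above; the optional `prefix` filter relies on run names starting with the encoder, fusion and channel settings, as in the example name.

```python
from pathlib import Path
from typing import List


def list_runs(task: str = "cls", root: str = ".", prefix: str = "") -> List[Path]:
    """Return checkpoint directories under run-cls/ or run-reg/ (hypothetical helper).

    Assumes one subdirectory per training configuration inside run-<task>/.
    `prefix` keeps only runs whose name starts with it, e.g. "cnn-att-16".
    """
    if task not in ("cls", "reg"):
        raise ValueError("task must be 'cls' or 'reg'")
    run_dir = Path(root) / f"run-{task}"
    if not run_dir.is_dir():
        return []  # no runs for this task yet
    return sorted(
        p for p in run_dir.iterdir() if p.is_dir() and p.name.startswith(prefix)
    )
```

For instance, `list_runs("cls", prefix="cnn-att-16")` returns only the CNN runs with attention fusion and 16 channels, and returns an empty list if `run-cls` does not exist yet.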
## Inference

To run inference, simply replace `main.py` with `infer.py` in your training command. If your checkpoints were trained with `main_simple.py`, replace `main_simple.py` with `infer.py` and add `--simple`.

For case-study scanning, use `infer_case.py` with the additional argument `--case r2` or `--case YOUR_PEPTIDE_SEQUENCE`.

Inference results are saved in `csv` format in the weights directory, e.g. `/run-cls/cnn-att-16-lstm-pcs-simple-llm-768-oneway-ce-32-0.001-35/preds_test.csv`.
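
The rule for turning a training command into an inference command can be sketched as a small argument-rewriting function. `to_inference_argv` is hypothetical and not part of the repository. It assumes, as the text suggests, that every other training argument is passed through unchanged, and that `infer_case.py` takes the same arguments as `infer.py` plus `--case`.

```python
import os
import shlex
from typing import List, Optional


def to_inference_argv(train_argv: List[str], case: Optional[str] = None) -> List[str]:
    """Turn a training command into the matching inference command (hypothetical helper).

    Swaps main.py / main_simple.py for infer.py (or infer_case.py when `case` is
    given), adds --simple for main_simple.py checkpoints, and keeps every other
    argument so that it matches the trained configuration.
    """
    argv = list(train_argv)
    idx = next(
        (i for i, a in enumerate(argv) if a.endswith(("main.py", "main_simple.py"))),
        None,
    )
    if idx is None:
        raise ValueError("no main.py or main_simple.py found in the command")
    from_simple = argv[idx].endswith("main_simple.py")
    script = "infer_case.py" if case is not None else "infer.py"
    argv[idx] = os.path.join(os.path.dirname(argv[idx]), script)  # keep any directory
    if from_simple and "--simple" not in argv:
        argv.append("--simple")
    if case is not None:
        argv += ["--case", case]
    return argv


if __name__ == "__main__":
    cmd = "python main_simple.py --q-encoder cnn --fusion att --task cls --pcs --resize 768"
    print(shlex.join(to_inference_argv(shlex.split(cmd))))
    # prints: python infer.py --q-encoder cnn --fusion att --task cls --pcs --resize 768 --simple
```

Passing `case="r2"` instead produces the corresponding `infer_case.py ... --case r2` command.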
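
The prediction file can then be loaded with `pandas` (already in the requirements). This is a minimal sketch: `load_predictions` is a hypothetical helper, and since this README does not document the CSV's columns, it makes no assumption about them. Inspect `df.columns` before using the data.

```python
from pathlib import Path

import pandas as pd


def load_predictions(run_dir: str, filename: str = "preds_test.csv") -> pd.DataFrame:
    """Load an inference CSV from a weights directory (hypothetical helper).

    The column layout is not documented here, so no columns are assumed.
    """
    path = Path(run_dir) / filename
    if not path.is_file():
        raise FileNotFoundError(f"no predictions found at {path}")
    return pd.read_csv(path)
```

For example, `load_predictions("run-cls/cnn-att-16-lstm-pcs-simple-llm-768-oneway-ce-32-0.001-35")` returns a `DataFrame` whose `shape` and `columns` show what the inference script wrote.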