import csv
import re
from typing import List

import click
import torch
| |
|
| |
|
def prepare_results(input_data: List[str], predictions: List[dict]) -> List[List[str]]:
    """Combine raw input lines with pipeline predictions into output rows.

    Each input line is split on commas (expected ``HLA,peptide``). The
    prediction's ``"Presentation score"`` and ``"Label"`` are appended to the
    resulting fields. Both are converted to ``str`` so every row can be
    written out as text.

    Args:
        input_data: Raw input lines, one peptide-HLA pair per line.
        predictions: One prediction dict per input line, in the same order.

    Returns:
        One row per input line: the comma-separated input fields followed by
        the score and the label, all as strings.

    Raises:
        ValueError: If ``input_data`` and ``predictions`` differ in length.
            Without this check ``zip`` would silently drop the unmatched rows.
    """
    if len(input_data) != len(predictions):
        raise ValueError(
            f"Got {len(input_data)} input lines but {len(predictions)} predictions"
        )
    return [
        phla.split(",") + [str(pred["Presentation score"]), str(pred["Label"])]
        for phla, pred in zip(input_data, predictions)
    ]
| |
|
| |
|
def _validate_device(ctx: click.Context, param: click.Parameter, value: str) -> str:
    """Click callback for ``--device``: accept only ``cpu`` or ``cuda:<index>``.

    ``re.fullmatch`` is used instead of ``re.match`` with ``$``, because ``$``
    would also accept a trailing newline (e.g. ``"cpu\\n"``).
    """
    if re.fullmatch(r"cpu|cuda:\d+", value) is None:
        raise click.BadParameter(
            "Device should be 'cpu' or a CUDA device (e.g. 'cuda:0')"
        )
    return value


def _write_results(output_file: str, results: List[List[str]]) -> None:
    """Write prediction rows to ``output_file`` as CSV, preceded by a header row.

    The ``csv`` module quotes any field that contains a comma or a quote
    instead of letting it shift the columns. Rows end with Unix line endings,
    as in the previous hand-written output.
    """
    # newline="" lets the csv writer control line endings itself.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["HLA", "Peptide", "Presentation score", "Label"])
        writer.writerows(results)


@click.command(
    help="This is a model for predicting the presentation of a peptide on MHC class 1. "
    "The model runs on inputs filtered by mhcflurry (and includes the filtering). "
    "The model was created as part of the Immunology platform of Ardigen. "
    "For more information, see https://huggingface.co/ardigen/ardisplay-i"
)
@click.option(
    "--input-file",
    "-i",
    type=click.Path(exists=True),
    required=True,
    help="Path to input CSV file. The file should contain one peptide-HLA pair per line,"
    " with the HLA and peptide separated by a comma (e.g. A02:01,CKTSPLSNWHT). "
    "Note that the file should not contain any headers.",
)
@click.option(
    "--output-file",
    "-o",
    default="ardisplay_predictions.csv",
    type=click.Path(),
    help="Path to output CSV file. Default is ardisplay_predictions.csv",
)
@click.option(
    "--batch-size",
    "-b",
    type=click.IntRange(min=1),
    default=100,
    help="Batch size for inference. Default is 100.",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    callback=_validate_device,
    help="Device to use for inference, either 'cpu' or a CUDA device (e.g. 'cuda:0'). Default is 'cpu'.",
)
def ardisplay_cli(input_file, output_file, batch_size, device):
    """Score every peptide-HLA line of ``input_file`` and write a CSV to ``output_file``."""
    # Imported inside the command, presumably so `--help` and argument errors
    # do not pay the cost of loading these libraries.
    from datasets import load_dataset
    from transformers import pipeline

    pipe = pipeline(
        model="ardigen/ardisplay-i",
        # Runs the custom pipeline code shipped in the model repository.
        trust_remote_code=True,
        batch_size=batch_size,
        device=torch.device(device),
    )
    # The "text" loader exposes the file's lines under the "text" column.
    dataset = load_dataset("text", data_files={"test": input_file}, split="test")
    lines = dataset["text"]
    predictions = pipe(lines)

    results = prepare_results(input_data=lines, predictions=predictions)
    _write_results(output_file, results)
    click.echo(f"Presentation scores calculated, outputting to {output_file}")
| |
|
| |
|
if __name__ == "__main__":
    # Calling a Click command object is an alias for Command.main(), which
    # parses sys.argv and runs the command.
    ardisplay_cli.main()
| |
|