|
|
import csv
import re
from typing import List

import click
import torch
|
|
|
|
|
|
|
|
def prepare_results(input_data: List[str], predictions: List[dict]) -> List[List[str]]:
    """Combine raw input lines with model predictions into output CSV rows.

    Args:
        input_data: Input lines, each expected to be an "HLA,peptide" string.
        predictions: One prediction dict per input line, in the same order;
            each must provide the "Presentation score" and "Label" keys.

    Returns:
        One row per input line: the comma-separated fields of the line,
        followed by the presentation score and the label, all as strings.

    Raises:
        ValueError: If the number of predictions differs from the number of
            input lines. Without this check, ``zip`` would silently drop rows.
    """
    if len(input_data) != len(predictions):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(input_data)} input lines"
        )
    return [
        # str() on both prediction fields so every cell is a string, as the
        # return annotation promises (a non-str label would otherwise break
        # callers that join the row).
        phla.split(",") + [str(pred["Presentation score"]), str(pred["Label"])]
        for phla, pred in zip(input_data, predictions)
    ]
|
|
|
|
|
|
|
|
@click.command(
    help="This is a model for predicting the presentation of a peptide on MHC class 1. "
    "The model runs on inputs filtered by mhcflurry (and includes the filtering). "
    "The model was created as part of the Immunology platform of Ardigen. "
    "For more information, see https://huggingface.co/ardigen/ardisplay-i"
)
@click.option(
    "--input-file",
    "-i",
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help="Path to input CSV file. The file should contain one peptide-HLA pair per line,"
    " with the HLA and peptide separated by a comma (e.g. A02:01,CKTSPLSNWHT). "
    "Note that the file should not contain any headers.",
)
@click.option(
    "--output-file",
    "-o",
    default="ardisplay_predictions.csv",
    type=click.Path(dir_okay=False),
    help="Path to output CSV file. Default is ardisplay_predictions.csv",
)
@click.option(
    "--batch-size",
    "-b",
    type=click.IntRange(min=1),
    default=100,
    help="Batch size for inference. Default is 100.",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    help="Device to use for inference, either 'cpu' or a CUDA device "
    "(e.g. 'cuda:0'). Default is 'cpu'.",
)
def ardisplay_cli(input_file, output_file, batch_size, device):
    """Score every HLA,peptide line of the input file and write a results CSV.

    The output CSV has the header ``HLA,Peptide,Presentation score,Label``.
    Each row holds the input line's fields followed by the model's
    presentation score and label.

    Raises:
        click.BadParameter: If ``device`` is malformed or refers to a CUDA
            device that is not present, or if the input file contains no
            non-blank lines.
    """
    # Validate cheap arguments before importing heavy libraries and loading
    # the model. fullmatch with a raw string avoids the invalid "\d" escape
    # warning and rejects values with a trailing newline (which "$" allowed).
    if not re.fullmatch(r"cpu|cuda:\d+", device):
        raise click.BadParameter(
            "Device should be 'cpu' or a CUDA device (e.g. 'cuda:0')",
            param_hint="'--device'",
        )
    torch_device = torch.device(device)
    # device_count() is 0 when CUDA is unavailable, so this one check covers
    # both "no CUDA at all" and "index out of range".
    if torch_device.type == "cuda" and torch_device.index >= torch.cuda.device_count():
        raise click.BadParameter(
            f"CUDA device '{device}' is not available "
            f"(found {torch.cuda.device_count()} CUDA device(s))",
            param_hint="'--device'",
        )

    from datasets import load_dataset
    from transformers import pipeline

    dataset = load_dataset("text", data_files={"test": input_file}, split="test")
    # Skip blank lines (e.g. trailing empty lines) so that only real
    # HLA,peptide entries are sent to the model.
    texts = [line for line in dataset["text"] if line.strip()]
    if not texts:
        raise click.BadParameter(
            "Input file contains no HLA,peptide pairs",
            param_hint="'--input-file'",
        )

    # NOTE: trust_remote_code=True executes model code downloaded from the
    # Hugging Face Hub repository named below.
    pipe = pipeline(
        model="ardigen/ardisplay-i",
        trust_remote_code=True,
        batch_size=batch_size,
        device=torch_device,
    )
    predictions = pipe(texts)

    results = prepare_results(input_data=texts, predictions=predictions)

    # csv.writer quotes any field that itself contains a comma or quote and
    # stringifies non-str cells; lineterminator="\n" keeps Unix line endings.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["HLA", "Peptide", "Presentation score", "Label"])
        writer.writerows(results)
    click.echo(f"Presentation scores calculated, outputting to {output_file}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the click command, which parses the
    # command-line arguments itself (click's Command.__call__ simply
    # delegates to Command.main).
    ardisplay_cli.main()
|
|
|