# ardisplay-i/cli/ardisplay_cli.py
# bartlomiejkrol-jozaga - release v1.0.0 (commit 20901a8)
import re
from typing import List
import click
import torch
def prepare_results(input_data: List[str], predictions: List[dict]) -> List[List[str]]:
    """Combine raw input lines with their predictions into output rows.

    Each input line is split on commas, then the stringified
    ``"Presentation score"`` and the ``"Label"`` of the matching prediction
    are appended. Inputs and predictions are paired positionally; pairing
    stops at the shorter of the two sequences.

    Args:
        input_data: Raw comma-separated input lines.
        predictions: One mapping per input line, holding the keys
            ``"Presentation score"`` and ``"Label"``.

    Returns:
        One list of fields per paired input line.
    """
    rows = []
    for line, prediction in zip(input_data, predictions):
        fields = line.split(",")
        fields.append(str(prediction["Presentation score"]))
        fields.append(prediction["Label"])
        rows.append(fields)
    return rows
@click.command(
    help="This is a model for prediction the presentation of a peptide on MHC class 1. "
    "The model runs on inputs filtered by mhcflurry (and includes the filtering). "
    "The model was created as part of Immunology platform of Ardigen. "
    "For more information, see https://huggingface.co/ardigen/ardisplay-i"
)
@click.option(
    "--input-file",
    "-i",
    type=click.Path(exists=True),
    required=True,
    help="Path to input CSV file. The file should contain one peptide-HLA pair per line,"
    " with the HLA and peptide separated by a comma (e.g. A02:01,CKTSPLSNWHT)'. "
    "Note that file should not contain any headers.",
)
@click.option(
    "--output-file",
    "-o",
    default="ardisplay_predictions.csv",
    type=click.Path(),
    help="Path to output CSV file. Default is ardisplay_predictions.csv",
)
@click.option(
    "--batch-size",
    "-b",
    type=int,
    default=100,
    help="Batch size for inference. Default is 100.",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    help="Device to use for inference, either 'cpu' or a CUDA device (e.g. 'cuda:0'). Default is 'cpu'.",
)
def ardisplay_cli(input_file, output_file, batch_size, device):
    """Run the ardigen/ardisplay-i pipeline on an input file and save the results.

    The input file is read line by line, every line is passed to the
    pipeline, and one CSV row per line is written to ``output_file`` under
    the header ``HLA,Peptide,Presentation score,Label``.

    Args:
        input_file: Path to an existing header-less file with one
            comma-separated pair per line.
        output_file: Path of the CSV file to create or overwrite.
        batch_size: Batch size forwarded to the pipeline.
        device: ``"cpu"`` or ``"cuda:<index>"``.

    Raises:
        click.BadParameter: If ``device`` is neither ``"cpu"`` nor of the
            form ``"cuda:<index>"``.
    """
    # NOTE(review): these imports are function-local, presumably so that
    # `--help` does not pay for importing them - confirm before hoisting.
    from datasets import load_dataset
    from transformers import pipeline

    # Raw string avoids the invalid "\d" escape; fullmatch rejects values
    # with a trailing newline that "$" would otherwise let through.
    if re.fullmatch(r"cpu|cuda:\d+", device) is None:
        raise click.BadParameter(
            "Device should be 'cpu' or a CUDA device (e.g. 'cuda:0')",
            param_hint="--device",
        )
    device = torch.device(device)

    # NOTE: trust_remote_code=True executes code shipped with the model
    # repository; only point this at a repository you trust.
    pipe = pipeline(
        model="ardigen/ardisplay-i",
        trust_remote_code=True,
        batch_size=batch_size,
        device=device,
    )

    dataset = load_dataset("text", data_files={"test": input_file}, split="test")
    # Read the column once and reuse it for both inference and output.
    input_lines = dataset["text"]
    predictions = pipe(input_lines)
    results = prepare_results(input_data=input_lines, predictions=predictions)

    # Explicit encoding keeps the output independent of the platform locale.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("HLA,Peptide,Presentation score,Label\n")
        f.writelines(",".join(result) + "\n" for result in results)

    click.echo(f"Presentation scores calculated, outputting to {output_file}")
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    ardisplay_cli()