import re
from typing import List

import click
import torch

# Accepted --device values: "cpu" or "cuda:<index>". A raw string keeps the
# regex escapes away from Python's string-escape processing, and fullmatch()
# rejects anything else (including a trailing newline, which "$" would allow).
_DEVICE_PATTERN = re.compile(r"cpu|cuda:[0-9]+")


def prepare_results(input_data: List[str], predictions: List[dict]) -> List[List[str]]:
    """Pair each input line with its prediction to form output CSV rows.

    Args:
        input_data: Raw input lines, each expected to be "HLA,peptide".
        predictions: Pipeline outputs, one per input line, in the same order.
            Each must provide "Presentation score" and "Label" keys.

    Returns:
        One row per input line: the comma-split fields of the line followed by
        the stringified presentation score and label.

    Raises:
        ValueError: If the number of predictions differs from the number of
            input lines (``zip`` would otherwise silently drop rows).
    """
    if len(input_data) != len(predictions):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(input_data)} input lines"
        )
    return [
        phla.split(",") + [str(pred["Presentation score"]), str(pred["Label"])]
        for phla, pred in zip(input_data, predictions)
    ]


@click.command(
    help="This is a model for predicting the presentation of a peptide on MHC class 1. "
    "The model runs on inputs filtered by mhcflurry (and includes the filtering). "
    "The model was created as part of the Immunology platform of Ardigen. "
    "For more information, see https://huggingface.co/ardigen/ardisplay-i"
)
@click.option(
    "--input-file",
    "-i",
    type=click.Path(exists=True),
    required=True,
    help="Path to input CSV file. The file should contain one peptide-HLA pair per line,"
    " with the HLA and peptide separated by a comma (e.g. A02:01,CKTSPLSNWHT). "
    "Note that the file should not contain a header.",
)
@click.option(
    "--output-file",
    "-o",
    default="ardisplay_predictions.csv",
    type=click.Path(),
    help="Path to output CSV file. Default is ardisplay_predictions.csv",
)
@click.option(
    "--batch-size",
    "-b",
    # Reject 0 / negative batch sizes at parse time instead of deep in inference.
    type=click.IntRange(min=1),
    default=100,
    help="Batch size for inference. Default is 100.",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    help="Device to use for inference, either 'cpu' or a CUDA device (e.g. 'cuda:0'). Default is 'cpu'.",
)
def ardisplay_cli(input_file, output_file, batch_size, device):
    """Score every HLA,peptide line of INPUT_FILE and write a CSV to OUTPUT_FILE."""
    # Validate cheap arguments before paying for the heavy imports below.
    if not _DEVICE_PATTERN.fullmatch(device):
        raise click.BadParameter(
            "Device should be 'cpu' or a CUDA device (e.g. 'cuda:0')",
            param_hint="'--device'",
        )

    # Deferred imports keep `--help` and argument errors fast.
    from datasets import load_dataset
    from transformers import pipeline

    # NOTE(security): trust_remote_code=True executes model code downloaded
    # from the Hugging Face Hub; only use with a trusted model repository.
    pipe = pipeline(
        model="ardigen/ardisplay-i",
        trust_remote_code=True,
        batch_size=batch_size,
        device=torch.device(device),
    )
    dataset = load_dataset("text", data_files={"test": input_file}, split="test")
    # Materialise the column once: it is both the model input and the row source.
    lines = list(dataset["text"])
    predictions = pipe(lines)
    results = prepare_results(input_data=lines, predictions=predictions)

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("HLA,Peptide,Presentation score,Label\n")
        f.writelines(",".join(row) + "\n" for row in results)
    click.echo(f"Presentation scores calculated, outputting to {output_file}")


if __name__ == "__main__":
    ardisplay_cli()