File size: 2,461 Bytes
20901a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import csv
import re
from typing import List

import click
import torch


def prepare_results(input_data: List[str], predictions: List[dict]) -> List[List[str]]:
    """Pair every input line with its prediction to form one output row.

    A row is the comma-separated fields of the input line, followed by the
    prediction's "Presentation score" (converted with ``str``) and its
    "Label" value.

    Args:
        input_data: Raw input lines, e.g. ``"A02:01,CKTSPLSNWHT"``.
        predictions: One mapping per input line holding the
            "Presentation score" and "Label" keys.

    Returns:
        One list of strings per paired line. Lines and predictions are
        paired with ``zip``, so extra items in the longer list are ignored.
    """
    rows = []
    for line, prediction in zip(input_data, predictions):
        fields = line.split(",")
        fields.append(str(prediction["Presentation score"]))
        fields.append(prediction["Label"])
        rows.append(fields)
    return rows


# Accepted --device values: "cpu" or an indexed CUDA device such as "cuda:0".
# [0-9] (not \d) so that only ASCII device indices are accepted.
_DEVICE_PATTERN = re.compile(r"cpu|cuda:[0-9]+")

# Column names written as the first line of the output CSV.
_OUTPUT_HEADER = ["HLA", "Peptide", "Presentation score", "Label"]


def _validate_device(ctx: click.Context, param: click.Parameter, value: str) -> str:
    """Click callback that rejects --device values other than 'cpu' or 'cuda:<N>'.

    Running during option parsing means a bad value is reported before the
    slow ``datasets``/``transformers`` imports in the command body.
    """
    # fullmatch (instead of re.match with ^...$) also rejects a trailing
    # newline, which "$" would otherwise accept.
    if _DEVICE_PATTERN.fullmatch(value) is None:
        raise click.BadParameter(
            "Device should be 'cpu' or a CUDA device (e.g. 'cuda:0')"
        )
    return value


def _write_results(output_file: str, rows: List[List[str]]) -> None:
    """Write the header line and ``rows`` to ``output_file`` as UTF-8 CSV.

    ``csv.writer`` quotes any field that contains a comma or a quote, so such
    a value cannot shift the columns. ``lineterminator`` is set so rows end
    with a plain LF rather than the csv module's default CRLF.
    """
    with open(output_file, "w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle, lineterminator="\n")
        writer.writerow(_OUTPUT_HEADER)
        writer.writerows(rows)


@click.command(
    help="This is a model for predicting the presentation of a peptide on MHC class 1. "
    "The model runs on inputs filtered by mhcflurry (and includes the filtering). "
    "The model was created as part of Immunology platform of Ardigen. "
    "For more information, see https://huggingface.co/ardigen/ardisplay-i"
)
@click.option(
    "--input-file",
    "-i",
    type=click.Path(exists=True),
    required=True,
    help="Path to input CSV file. The file should contain one peptide-HLA pair per line,"
    " with the HLA and peptide separated by a comma (e.g. A02:01,CKTSPLSNWHT). "
    "Note that the file should not contain any headers.",
)
@click.option(
    "--output-file",
    "-o",
    default="ardisplay_predictions.csv",
    type=click.Path(),
    help="Path to output CSV file. Default is ardisplay_predictions.csv",
)
@click.option(
    "--batch-size",
    "-b",
    type=click.IntRange(min=1),
    default=100,
    help="Batch size for inference. Default is 100.",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    callback=_validate_device,
    help="Device to use for inference, either 'cpu' or a CUDA device (e.g. 'cuda:0'). Default is 'cpu'.",
)
def ardisplay_cli(input_file, output_file, batch_size, device):
    """Score every HLA/peptide line of ``input_file`` and write a CSV report.

    ``device`` has already been checked by ``_validate_device`` and
    ``batch_size`` is guaranteed to be >= 1 by ``click.IntRange``.
    """
    # Imported inside the command, presumably so that `--help` and option
    # errors do not pay for loading these heavy libraries.
    from datasets import load_dataset
    from transformers import pipeline

    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # "ardigen/ardisplay-i" Hub repository; only safe while that repo is trusted.
    pipe = pipeline(
        model="ardigen/ardisplay-i",
        trust_remote_code=True,
        batch_size=batch_size,
        device=torch.device(device),
    )
    dataset = load_dataset("text", data_files={"test": input_file}, split="test")
    # Read the column once; it is used both as model input and for the output rows.
    phla_lines = dataset["text"]
    predictions = pipe(phla_lines)

    results = prepare_results(input_data=phla_lines, predictions=predictions)
    _write_results(output_file, results)
    click.echo(f"Presentation scores calculated, outputting to {output_file}")


if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses the
    # command-line arguments itself.
    ardisplay_cli()