File size: 3,385 Bytes
9d36a4d
fed6436
9d36a4d
fed6436
f9c9b95
fed6436
9d36a4d
 
fed6436
9d36a4d
 
 
 
fed6436
9d36a4d
 
fed6436
 
9d36a4d
 
fed6436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d36a4d
f9c9b95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# sdg_predict/inference.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
import logging
import json


def load_model(model_name, device):
    """Load a sequence-classification checkpoint and its tokenizer.

    Returns a ``(tokenizer, model)`` pair; the model is moved to *device*
    and put into eval mode before being returned.
    """
    tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name)
    clf = clf.to(device)
    clf.eval()
    return tok, clf


def batched(iterable, batch_size):
    """Yield consecutive slices of *iterable*, each at most *batch_size* long.

    Requires a sized, sliceable input (list, str, tuple, ...); the final
    slice may be shorter than *batch_size*.
    """
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start : start + batch_size]
        start += batch_size


def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* with a HuggingFace text-classification pipeline.

    When *return_all_scores* is true, each text maps to a list of
    per-label score dicts; otherwise only the top label is kept.
    Scores are rounded to three decimal places in place.
    """
    top_k = None if return_all_scores else 1
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        truncation=True,
        padding=True,
        max_length=512,
        top_k=top_k,
    )

    outputs = classifier(texts)

    if return_all_scores:
        # One list of {"label", "score"} dicts per text: round each score.
        for per_text in outputs:
            for entry in per_text:
                entry["score"] = round(entry["score"], 3)
    else:
        # One {"label", "score"} dict per text: round its single score.
        for entry in outputs:
            entry["score"] = round(entry["score"], 3)

    return outputs


def binary_from_softmax(prediction, cap_class0=0.5, num_classes=17):
    """Convert multi-class softmax scores into per-class binary scores.

    Each positive class ``"1"``..``str(num_classes)`` is scored against the
    "none" class ``"0"`` by pairwise renormalisation:
    ``score / (score + score_0)``. The class-"0" score is capped at
    *cap_class0* so a very confident "none" prediction cannot fully
    suppress the positive classes.

    Args:
        prediction: list of ``{"label": str, "score": float}`` dicts, as
            produced by a text-classification pipeline with all scores kept.
        cap_class0: upper bound applied to the class-"0" score before
            renormalisation.
        num_classes: number of positive classes; labels are the strings
            ``"1"`` through ``str(num_classes)`` (default 17, the SDG goals
            1-17 implied by the original hard-coded range).

    Returns:
        dict mapping every positive label to its binary score rounded to
        3 decimal places; labels absent from *prediction* remain 0.0.
        Labels outside the expected set (including "0") are ignored rather
        than inserted into the result.
    """
    # Score of the "none" class, 0.0 if absent, clamped to the cap.
    score_0 = next((x["score"] for x in prediction if x["label"] == "0"), 0.0)
    score_0 = min(score_0, cap_class0)

    # Initialize every expected positive label to 0.0 (closed label set).
    binary_predictions = {str(i): 0.0 for i in range(1, num_classes + 1)}

    for entry in prediction:
        label = entry["label"]
        if label not in binary_predictions:
            # Skip class "0" and any unexpected labels instead of
            # silently adding new keys to the result dict.
            continue
        score = entry["score"]
        denom = score + score_0
        binary_score = score / denom if denom > 0 else 0.0
        binary_predictions[label] = round(binary_score, 3)

    return binary_predictions


def setup_device():
    """Select the best available torch device.

    Preference order is MPS (Apple silicon), then CUDA, then CPU;
    the choice is logged at INFO level.
    """
    logging.info("Setting up device")
    if torch.backends.mps.is_available():
        logging.info("Using MPS device")
        return torch.device("mps")
    if torch.cuda.is_available():
        logging.info("Using CUDA device")
        return torch.device("cuda")
    logging.info("Using CPU device")
    return torch.device("cpu")


def load_model_and_tokenizer(model_name, device):
    """Thin logging wrapper around :func:`load_model`.

    Returns the same ``(tokenizer, model)`` pair as :func:`load_model`.
    """
    logging.info("Loading model: %s", model_name)
    pair = load_model(model_name, device)
    logging.info("Model loaded successfully")
    return pair


def load_input_data(input, key):
    """Read a JSON-lines file, keeping only rows that contain *key*.

    NOTE(review): the parameter name ``input`` shadows the builtin, but
    renaming it would break callers passing it by keyword.

    Args:
        input: a path-like object exposing ``.open()`` (e.g. pathlib.Path)
            to a file with one JSON object per line.
        key: the field whose value is collected into the texts list.

    Returns:
        ``(texts, rows)`` — the *key* values and the full parsed rows,
        in file order; rows missing *key* are skipped entirely.
    """
    logging.info("Loading input data from %s", input)
    texts, rows = [], []
    with input.open() as handle:
        for raw_line in handle:
            record = json.loads(raw_line)
            if key in record:
                logging.debug("Text: %s", record[key])
                texts.append(record[key])
                rows.append(record)
    logging.info("Loaded %d rows of input data", len(rows))
    return texts, rows


def perform_predictions(texts, tokenizer, model, device, batch_size, top1):
    """Run :func:`predict` over *texts* with progress logging.

    When *top1* is true only the best label per text is kept; otherwise
    all per-label scores are returned.
    """
    logging.info("Starting predictions on %d texts", len(texts))
    results = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=batch_size,
        return_all_scores=not top1,
    )
    logging.info("Predictions completed")
    return results