File size: 3,178 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from classifier.head import ClassifierHead
from classifier.utils import CATEGORIES, CHECKPOINT_PATH, DEVICE, get_models, CLASSIFIER_NAME, get_latest_checkpoint

import argparse
import pprint
import torch
from sentence_transformers import SentenceTransformer

def classifier_init(
    checkpoint_path: str | None = None,
    model_id: str | None = CLASSIFIER_NAME,
) -> tuple[SentenceTransformer, ClassifierHead]:
    """Load the embedding model and classifier head.

    Args:
        checkpoint_path: Directory containing saved checkpoints. When given,
            the most recent checkpoint found there is loaded and ``model_id``
            is ignored.
        model_id: Model identifier passed to ``get_models`` when no
            checkpoint path is provided.

    Returns:
        An ``(embedding_model, classifier)`` pair.
    """
    # Fix: the original annotated the return as `(SentenceTransformer,
    # ClassifierHead)` — a tuple *literal*, not a valid type hint.
    if checkpoint_path:
        latest_checkpoint = get_latest_checkpoint(checkpoint_path)
        print(f"Loading checkpoint from {latest_checkpoint}")
        embedding_model, classifier = get_models(model_id=latest_checkpoint)
    else:
        embedding_model, classifier = get_models(model_id=model_id)

    return embedding_model, classifier

def predict_query(
    text: list[str],
    embedding_model: SentenceTransformer,
    classifier_head: ClassifierHead,
) -> dict:
    """Run the full inference pipeline: text -> embedding -> classification.

    Args:
        text: Batch of query strings to classify.
        embedding_model: Sentence-embedding model used to encode ``text``.
        classifier_head: Head exposing ``predict_proba`` over embeddings.

    Returns:
        Dict with one entry per query in each value:
            'prediction': list of predicted category names.
            'confidence': list of probabilities of the predicted class.
            'probabilities': list of per-class probability rows.
    """
    # Set models to evaluation mode (disable dropout etc.).
    embedding_model.eval()
    classifier_head.eval()

    with torch.no_grad():
        # Embed the batch of queries.
        embeddings = embedding_model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE
        ).to(DEVICE)

        # Class probabilities, shape (batch, num_classes).
        probabilities = classifier_head.predict_proba(embeddings)

        # Fix: torch.max returns both the top probability and its index in
        # one pass (replacing argmax + gather). The original's bare
        # .squeeze() collapsed a single-query batch to a scalar float /
        # flat row, so the return types varied with batch size; per-query
        # lists are now returned consistently.
        confidences, predicted_indices = torch.max(probabilities, dim=1)

        # Map class indices to human-readable category names.
        predicted_labels = [CATEGORIES[i] for i in predicted_indices.tolist()]

    return {
        'prediction': predicted_labels,
        'confidence': confidences.tolist(),
        'probabilities': probabilities.cpu().tolist()
    }

def test(local: bool = False):
    """Smoke-test the classifier against a handful of sample health queries."""
    embedding_model, classifier = classifier_init(
        checkpoint_path=CHECKPOINT_PATH if local else None
    )

    # Sample user queries, typos and all, spanning the triage categories.
    sample_queries = [
        "Hi! I'm having a really bad rash on my hands. I'm pretty sure it's my excema flairing up. Is there anythign stronger than aquaphor I can use on it?",
        "Hey is there any way I can get an appointment in the next month?",
        "Hey is there any way I can get an appointment in the next month with a doctor?",
        "I'm traveling to South America soon. Do I need to get any vaccines before I go?",
        "I have this rash that popped up today.",
        "How can I make this hosptial bill go away?",
        "I'm so confused do I have to cover the full cost of this operation?",
    ]

    result = predict_query(
        text=sample_queries,
        embedding_model=embedding_model,
        classifier_head=classifier,
    )

    pprint.pprint(result, indent=4)

if __name__ == "__main__":
    # CLI entry point: optionally load the most recent local checkpoint.
    parser = argparse.ArgumentParser(
        description="Inference on a classifier for triaging health queries"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Use local checkpoint",
    )
    cli_args = parser.parse_args()
    test(local=cli_args.local)