|
|
from classifier.head import ClassifierHead
|
|
|
from classifier.utils import CATEGORIES, CHECKPOINT_PATH, DEVICE, get_models, CLASSIFIER_NAME, get_latest_checkpoint
|
|
|
|
|
|
import argparse
|
|
|
import pprint
|
|
|
import torch
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
def classifier_init(
    checkpoint_path: str | None = None,
    model_id: str | None = CLASSIFIER_NAME,
) -> tuple[SentenceTransformer, ClassifierHead]:
    """
    Load the embedding model and classifier head used for inference.

    Args:
        checkpoint_path: When truthy, it is passed to ``get_latest_checkpoint``
            and the value that call returns is used as the model id instead
            of ``model_id``.
        model_id: Identifier handed to ``get_models`` when no checkpoint path
            is given. Defaults to ``CLASSIFIER_NAME``.

    Returns:
        The ``(embedding_model, classifier)`` pair produced by ``get_models``.
    """
    if checkpoint_path:
        # A checkpoint location takes precedence over the named model.
        latest_checkpoint = get_latest_checkpoint(checkpoint_path)
        print(f"Loading checkpoint from {latest_checkpoint}")
        model_id = latest_checkpoint

    embedding_model, classifier = get_models(model_id=model_id)
    return embedding_model, classifier
|
|
|
|
|
|
def predict_query(
    text: list[str],
    embedding_model: SentenceTransformer,
    classifier_head: ClassifierHead,
) -> dict:
    """
    Runs the full inference pipeline: Text -> Embedding -> Classification.

    Args:
        text: Queries to classify. A bare string is treated as a batch of one.
        embedding_model: Model whose ``encode`` turns the queries into
            embeddings.
        classifier_head: Head whose ``predict_proba`` turns embeddings into
            per-category probabilities.

    Returns:
        A dict whose values always hold one entry per input query, regardless
        of batch size:

        * ``'prediction'``: the ``CATEGORIES`` entry at the highest-probability
          index for each query.
        * ``'confidence'``: the probability of that predicted entry.
        * ``'probabilities'``: the full probability row for each query.
    """
    if isinstance(text, str):
        text = [text]

    embedding_model.eval()
    classifier_head.eval()

    if not text:
        # Nothing to encode; keep the result shape consistent with the
        # non-empty case instead of failing inside the dim=1 reductions.
        return {'prediction': [], 'confidence': [], 'probabilities': []}

    with torch.no_grad():
        embeddings = embedding_model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE,
        ).to(DEVICE)

        # NOTE(review): the dim=1 reductions below assume predict_proba returns
        # a 2-D (batch, num_categories) tensor - confirm against ClassifierHead.
        probabilities = classifier_head.predict_proba(embeddings)

        predicted_indices = torch.argmax(probabilities, dim=1)
        # Squeeze only the gathered column. An argument-less squeeze() would
        # also drop the batch dimension for a single query and turn the
        # result into a bare float.
        confidences = torch.gather(
            probabilities, dim=1, index=predicted_indices.unsqueeze(1)
        ).squeeze(1)

    # Index CATEGORIES with plain Python ints rather than 1-element tensors.
    predicted_labels = [CATEGORIES[i] for i in predicted_indices.tolist()]

    return {
        'prediction': predicted_labels,
        'confidence': confidences.tolist(),
        # No squeeze here, so a single query still yields a list of rows.
        'probabilities': probabilities.cpu().tolist(),
    }
|
|
|
|
|
|
def test(local: bool = False):
    """
    Smoke-test the pipeline on a fixed set of sample queries.

    Loads the models through ``classifier_init`` (passing ``CHECKPOINT_PATH``
    when ``local`` is true, otherwise no checkpoint path), classifies the
    sample queries with ``predict_query`` and pretty-prints the result.
    """
    checkpoint = CHECKPOINT_PATH if local else None
    encoder, head = classifier_init(checkpoint_path=checkpoint)

    sample_queries = [
        "Hi! I'm having a really bad rash on my hands. I'm pretty sure it's my excema flairing up. Is there anythign stronger than aquaphor I can use on it?",
        "Hey is there any way I can get an appointment in the next month?",
        "Hey is there any way I can get an appointment in the next month with a doctor?",
        "I'm traveling to South America soon. Do I need to get any vaccines before I go?",
        "I have this rash that popped up today.",
        "How can I make this hosptial bill go away?",
        "I'm so confused do I have to cover the full cost of this operation?",
    ]

    results = predict_query(
        text=sample_queries,
        embedding_model=encoder,
        classifier_head=head,
    )
    pprint.pprint(results, indent=4)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the optional --local flag and run the
    # sample-query smoke test with it.
    parser = argparse.ArgumentParser(
        description="Inference on a classifier for triaging health queries",
    )
    parser.add_argument("--local", action="store_true", help="Use local checkpoint")
    cli_args = parser.parse_args()

    test(local=cli_args.local)
|
|
|
|