|
|
"""
|
|
|
Inference module for Healthcare Reason Classification
|
|
|
|
|
|
This module provides inference for the reason classification system,
|
|
|
separate from the medical/insurance classifier.
|
|
|
"""
|
|
|
|
|
|
from ..head import ClassifierHead
|
|
|
from datetime import datetime
|
|
|
import os
|
|
|
import pprint
|
|
|
import torch
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
# Index -> label mapping for the classifier head's output classes.
# The integer keys are the logit/probability column indices.
REASON_CATEGORIES = {
    0: "ROUTINE_CARE",
    1: "PAIN_CONDITIONS",
    2: "INJURIES",
    3: "SKIN_CONDITIONS",
    4: "STRUCTURAL_ISSUES",
    5: "PROCEDURES"
}

# Directory scanned for trained classifier-head checkpoints (*.pt files).
REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
# Timestamp format; presumably embedded in checkpoint filenames by the
# training code — TODO confirm (not referenced elsewhere in this module).
DATETIME_FORMAT = "%Y%m%d_%H%M%S"
# Embedding backbone loaded via SentenceTransformer.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
|
|
|
|
|
|
def get_device():
    """Return the best available torch device, preferring MPS, then CUDA, then CPU."""
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in probes:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
|
|
|
|
|
|
# Resolved once at import time; every model and tensor in this module is
# moved to this device.
DEVICE = get_device()
|
|
|
|
|
|
def get_reason_models():
    """Build the embedding model and classifier head for reason inference.

    Returns a ``(SentenceTransformer, ClassifierHead)`` pair, both moved to
    DEVICE.  The classifier head is freshly constructed (untrained); callers
    load checkpoint weights separately.
    """
    prompt_templates = {
        'classification': 'task: healthcare reason classification | query: ',
        'retrieval (query)': 'task: search result | query: ',
        'retrieval (document)': 'title: {title | "none"} | text: ',
    }

    encoder = SentenceTransformer(
        MODEL_NAME,
        prompts=prompt_templates,
        default_prompt_name='classification',
    )

    head = ClassifierHead(len(REASON_CATEGORIES))

    return encoder.to(DEVICE), head.to(DEVICE)
|
|
|
|
|
|
def predict_reason_query(
    text: list[str],
    embedding_model: SentenceTransformer,
    classifier_head: ClassifierHead,
) -> dict:
    """
    Runs the full inference pipeline for reason classification: Text -> Embedding -> Classification.

    Returns a dict with parallel lists 'prediction' (labels), 'confidence'
    (probability of the predicted class) and 'probabilities' (full rows).
    """
    embedding_model.eval()
    classifier_head.eval()

    with torch.no_grad():
        embeddings = embedding_model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE,
        ).to(DEVICE)

        probabilities = classifier_head.predict_proba(embeddings)

        winners = torch.argmax(probabilities, dim=1)

        # A 0-dim tensor cannot be iterated; wrap the scalar index in a list.
        if winners.dim() == 0:
            indices = [winners.item()]
        else:
            indices = winners.cpu().tolist()

        confidences = [
            (probabilities[row][col] if probabilities.dim() > 1 else probabilities[col]).item()
            for row, col in enumerate(indices)
        ]

        labels = [REASON_CATEGORIES[i] for i in indices]

        return {
            'prediction': labels,
            'confidence': confidences,
            'probabilities': probabilities.cpu().tolist()
        }
|
|
|
|
|
|
def _load_latest_reason_weights(classifier_head: ClassifierHead) -> None:
    """Load the newest trained checkpoint into *classifier_head*, if any.

    Fix over the previous inline code: os.listdir order is arbitrary, so an
    unsorted scan loaded a random checkpoint.  Checkpoint filenames are
    assumed to embed a DATETIME_FORMAT timestamp (NOTE(review): confirm
    against the training code), so a reverse lexicographic sort tries the
    most recent file first.  Files that fail to load are skipped with a
    warning; the head is left untrained when nothing loads.
    """
    if not os.path.exists(REASON_CHECKPOINT_PATH):
        return
    # Newest first.
    for fname in sorted(os.listdir(REASON_CHECKPOINT_PATH), reverse=True):
        if not fname.endswith('.pt'):
            continue
        checkpoint_path = os.path.join(REASON_CHECKPOINT_PATH, fname)
        try:
            # weights_only=True avoids unpickling arbitrary objects.
            state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
            classifier_head.load_state_dict(state_dict)
            print(f"Loaded trained weights from {checkpoint_path}")
            return
        except Exception as e:
            print(f"Could not load weights from {checkpoint_path}: {e}")


def predict_single_reason(query: str) -> dict:
    """Convenience function to predict a single reason query.

    Returns a dict with keys 'query', 'category', 'confidence' and
    'probabilities' (label -> probability).  On any failure a neutral
    uniform-probability fallback is returned with an added 'error' key
    instead of raising, so callers never see an exception.
    """
    try:
        embedding_model, classifier_head = get_reason_models()
        _load_latest_reason_weights(classifier_head)

        result = predict_reason_query([query], embedding_model, classifier_head)

        # predict_reason_query always returns parallel lists (one entry per
        # input text); we passed a single query, so take element 0 directly
        # instead of the previous defensive isinstance() sniffing.
        prediction = result['prediction'][0]
        confidence = result['confidence'][0]

        # 'probabilities' is a batch (list of rows); unwrap the single row.
        probabilities = result['probabilities']
        if probabilities and isinstance(probabilities[0], list):
            probabilities = probabilities[0]

        # Pad with 0.0 in case the head emitted fewer columns than categories.
        prob_dict = {
            category: float(probabilities[i]) if i < len(probabilities) else 0.0
            for i, category in REASON_CATEGORIES.items()
        }

        return {
            'query': query,
            'category': prediction,
            'confidence': confidence,
            'probabilities': prob_dict
        }
    except Exception as e:
        # Deliberate best-effort fallback: report a uniform distribution and
        # surface the error text rather than crashing the caller.
        return {
            'query': query,
            'category': 'GENERAL_MEDICAL',
            'confidence': 0.5,
            'probabilities': {category: 1.0/len(REASON_CATEGORIES) for category in REASON_CATEGORIES.values()},
            'error': str(e)
        }
|
|
|
|
|
|
def test_reason_classifier():
    """Smoke-test the reason classifier with sample queries.

    Locates the newest checkpoint (if any), loads it into a fresh classifier
    head, runs a batch of example queries through the pipeline and
    pretty-prints the predictions.
    """
    path = ""

    if os.path.exists(REASON_CHECKPOINT_PATH):
        # Newest first: filenames are assumed to embed a DATETIME_FORMAT
        # timestamp, so reverse lexicographic order is reverse chronological
        # (os.listdir alone returns files in arbitrary order).
        for d in sorted(os.listdir(REASON_CHECKPOINT_PATH), reverse=True):
            if d.endswith('.pt'):
                checkpoint_path = os.path.join(REASON_CHECKPOINT_PATH, d)
                print(f"Found checkpoint: {checkpoint_path}")
                path = checkpoint_path
                break
        if not path:
            # Directory exists but contains no .pt files.
            print("No trained checkpoints found. Using untrained model.")
    else:
        # BUG FIX: this message was previously printed whenever a checkpoint
        # WAS found (inverted branch); it belongs on the missing-directory case.
        print("No checkpoint directory found. Using untrained model.")

    embedding_model, classifier = get_reason_models()

    if path and os.path.exists(path):
        try:
            state_dict = torch.load(path, weights_only=True, map_location=DEVICE)
            classifier.load_state_dict(state_dict)
            print(f"Loaded trained weights from {path}")
        except Exception as e:
            print(f"Could not load weights: {e}. Using untrained model.")

    queries = [
        "I have heel pain when I walk",
        "My toenail is ingrown and painful",
        "I need routine foot care",
        "I sprained my ankle playing sports",
        "I have plantar fasciitis",
        "I need a cortisone injection"
    ]

    print("\nTesting reason classification:")
    pred = predict_reason_query(
        text=queries,
        embedding_model=embedding_model,
        classifier_head=classifier,
    )

    pprint.pprint(pred, indent=4)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
test_reason_classifier() |