Spaces:
Sleeping
Sleeping
| from classifier.head import ClassifierHead | |
| from classifier.utils import CATEGORIES, CHECKPOINT_PATH, DEVICE, get_models, CLASSIFIER_NAME, get_latest_checkpoint | |
| import argparse | |
| import pprint | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
def classifier_init(
    checkpoint_path: str | None = None,
    model_id: str | None = CLASSIFIER_NAME,
) -> tuple[SentenceTransformer, ClassifierHead]:
    """Load the sentence-embedding model and the classifier head.

    Args:
        checkpoint_path: Location passed to ``get_latest_checkpoint``. When it
            is truthy, the checkpoint it resolves to is loaded and ``model_id``
            is ignored.
        model_id: Identifier handed to ``get_models`` when no
            ``checkpoint_path`` is given. Defaults to ``CLASSIFIER_NAME``.

    Returns:
        An ``(embedding_model, classifier)`` tuple unpacked from ``get_models``.
    """
    source = model_id
    if checkpoint_path:
        # NOTE(review): assumes get_latest_checkpoint always yields something
        # get_models can load; confirm its behaviour when no checkpoint exists.
        source = get_latest_checkpoint(checkpoint_path)
        print(f"Loading checkpoint from {source}")
    # Unpack explicitly so a get_models result of the wrong arity fails here,
    # and the caller always receives a real 2-tuple.
    embedding_model, classifier = get_models(model_id=source)
    return embedding_model, classifier
def predict_query(
    text: list[str] | str,
    embedding_model: SentenceTransformer,
    classifier_head: ClassifierHead,
) -> dict:
    """Run the full inference pipeline: text -> embedding -> classification.

    Args:
        text: Queries to classify. A single string is treated as a one-element
            batch.
        embedding_model: Model whose ``encode`` embeds ``text`` on ``DEVICE``.
        classifier_head: Head whose ``predict_proba`` maps the embeddings to
            per-class probabilities.

    Returns:
        A dict of per-query lists, one entry per input query, regardless of
        batch size:

        - ``'prediction'``: predicted label, looked up in ``CATEGORIES``.
        - ``'confidence'``: probability of the predicted label.
        - ``'probabilities'``: the full probability row for that query.

        Empty input yields empty lists.
    """
    if isinstance(text, str):
        text = [text]
    if not text:
        return {'prediction': [], 'confidence': [], 'probabilities': []}

    # Inference mode for both modules (the classifier head is a separate module).
    embedding_model.eval()
    classifier_head.eval()

    with torch.no_grad():
        embeddings = embedding_model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE,
        ).to(DEVICE)

        # NOTE(review): assumes predict_proba returns a (batch, num_classes)
        # tensor -- the original argmax/gather over dim=1 relies on the same.
        probabilities = classifier_head.predict_proba(embeddings)

        # One reduction over the class axis yields both the winning
        # probability and its index (first max wins on ties, as with argmax).
        confidences, predicted_indices = probabilities.max(dim=1)

        # Convert to plain ints before indexing: iterating a tensor yields
        # tensors, which are not valid keys for a dict-like CATEGORIES.
        predicted_labels = [CATEGORIES[i] for i in predicted_indices.tolist()]

    # No squeeze(): keep one entry per query even when the batch size is 1,
    # so every value in the result is shaped consistently with 'prediction'.
    return {
        'prediction': predicted_labels,
        'confidence': confidences.cpu().tolist(),
        'probabilities': probabilities.cpu().tolist(),
    }
def test(local: bool = False):
    """Smoke-test the classifier on a fixed set of sample health queries.

    Models come from ``classifier_init``: the checkpoint under
    ``CHECKPOINT_PATH`` when ``local`` is true, otherwise the default model
    id. The ``predict_query`` result for the samples is pretty-printed.

    Args:
        local: Load from the local checkpoint directory instead of the
            default model.
    """
    if local:
        checkpoint_dir = CHECKPOINT_PATH
    else:
        checkpoint_dir = None
    embedder, head = classifier_init(checkpoint_path=checkpoint_dir)

    # NOTE(review): the misspellings below may be deliberate (realistic user
    # input); they are the data under test, so they are left verbatim.
    sample_queries = [
        "Hi! I'm having a really bad rash on my hands. I'm pretty sure it's my excema flairing up. Is there anythign stronger than aquaphor I can use on it?",
        "Hey is there any way I can get an appointment in the next month?",
        "Hey is there any way I can get an appointment in the next month with a doctor?",
        "I'm traveling to South America soon. Do I need to get any vaccines before I go?",
        "I have this rash that popped up today.",
        "How can I make this hosptial bill go away?",
        "I'm so confused do I have to cover the full cost of this operation?",
    ]

    result = predict_query(
        text=sample_queries,
        embedding_model=embedder,
        classifier_head=head,
    )
    pprint.pprint(result, indent=4)
if __name__ == "__main__":
    # CLI entry point: `--local` switches the smoke test to the local checkpoint.
    parser = argparse.ArgumentParser(description="Inference on a classifier for triaging health queries")
    parser.add_argument("--local", action="store_true", help="Use local checkpoint")
    cli_args = parser.parse_args()
    test(local=cli_args.local)