""" Utilities for Healthcare Classification System This module contains shared constants and utilities for the healthcare classification system. """ from classifier.head import ClassifierHead from classifier.config import load_env import os from sentence_transformers import SentenceTransformer import torch from datetime import datetime from pathlib import Path # Load environment variables (including HF_TOKEN) load_env() MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical" CLASSIFIER_NAME = "davidgray/health-query-triage" CATEGORIES: list[str] = ["medical", "insurance"] # Model and training configuration MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical" CHECKPOINT_PATH = "classifier/checkpoints" DATETIME_FORMAT = "%Y%m%d_%H%M%S" # Device configuration - use David's newer approach with fallback try: DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" except AttributeError: # Fallback for older PyTorch versions if torch.backends.mps.is_available(): DEVICE = torch.device("mps") elif torch.cuda.is_available(): DEVICE = torch.device("cuda") else: DEVICE = torch.device("cpu") print(f"Using {DEVICE} device") def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]: """ Loads embeddinggemma-300m-medical model and initializes the classification head. Returns: tuple: (embedding_model, classifier_head) """ try: model_body = SentenceTransformer( MODEL_NAME, prompts={ 'classification': 'task: classification | query: ', 'retrieval (query)': 'task: search result | query: ', 'retrieval (document)': 'title: {title | "none"} | text: ', }, default_prompt_name='classification', ) if model_id: model_head = ClassifierHead.from_pretrained(model_id) else: model_head = ClassifierHead(num_labels) except Exception as e: print(f"Error loading model {MODEL_NAME}: {e}") print("Please ensure you have an internet connection and the transformers library installed.") raise RuntimeError("Failed to load the embedding model.") return model_body.to(DEVICE), model_head.to(DEVICE) def get_latest_checkpoint(checkpoint_path: str): return os.path.join(checkpoint_path, sorted(os.listdir(checkpoint_path))[-1])