Spaces:
Sleeping
Sleeping
File size: 2,604 Bytes
b7f3196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
"""
Utilities for Healthcare Classification System
This module contains shared constants and utilities for the healthcare
classification system.
"""
from classifier.head import ClassifierHead
from classifier.config import load_env
import os
from sentence_transformers import SentenceTransformer
import torch
from datetime import datetime
from pathlib import Path
# Load environment variables (including HF_TOKEN) before any model access.
load_env()

# Model and training configuration.
# NOTE(review): MODEL_NAME was previously defined twice with the identical
# value; the duplicate has been removed.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
CLASSIFIER_NAME = "davidgray/health-query-triage"
CATEGORIES: list[str] = ["medical", "insurance"]
CHECKPOINT_PATH = "classifier/checkpoints"
DATETIME_FORMAT = "%Y%m%d_%H%M%S"
# Device configuration - use David's newer approach with fallback.
# DEVICE is always a plain string ("cuda", "mps", "cpu", ...) so both code
# paths are consistent: the torch.accelerator path yields `.type` (a str),
# and previously the fallback yielded torch.device objects, making DEVICE's
# type depend on the installed PyTorch version.
try:
    DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
except AttributeError:
    # Fallback for older PyTorch versions without torch.accelerator
    if torch.backends.mps.is_available():
        DEVICE = "mps"
    elif torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"Using {DEVICE} device")
def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]:
    """
    Load the embeddinggemma-300m-medical model and initialize the classification head.

    Args:
        model_id: Optional Hugging Face model id of a pretrained classifier
            head. When None, a freshly initialized head is created instead.
        num_labels: Number of output classes for a newly initialized head
            (ignored when model_id is given). Defaults to len(CATEGORIES).

    Returns:
        tuple: (embedding_model, classifier_head), both moved to DEVICE.

    Raises:
        RuntimeError: If the embedding model or classifier head cannot be
            loaded (original exception is chained as __cause__).
    """
    try:
        # Prompt templates for EmbeddingGemma; 'classification' is the
        # default prompt used for encoding queries in this system.
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )
        if model_id:
            model_head = ClassifierHead.from_pretrained(model_id)
        else:
            model_head = ClassifierHead(num_labels)
    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        print("Please ensure you have an internet connection and the transformers library installed.")
        # Chain the original exception so the root cause is not lost.
        raise RuntimeError("Failed to load the embedding model.") from e
    return model_body.to(DEVICE), model_head.to(DEVICE)
def get_latest_checkpoint(checkpoint_path: str) -> str:
    """
    Return the path of the most recent checkpoint entry in a directory.

    Entries are assumed to be named with DATETIME_FORMAT timestamps, so
    lexicographic order matches chronological order and the maximum name
    is the newest checkpoint.

    Args:
        checkpoint_path: Directory containing checkpoint entries.

    Returns:
        str: checkpoint_path joined with the newest entry's name.

    Raises:
        FileNotFoundError: If checkpoint_path does not exist, or exists but
            contains no entries (previously an opaque IndexError).
    """
    entries = os.listdir(checkpoint_path)
    if not entries:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_path!r}")
    # max() picks the lexicographically-last name without sorting everything.
    return os.path.join(checkpoint_path, max(entries))
|