File size: 2,604 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Utilities for Healthcare Classification System.

This module contains shared constants and utilities for the healthcare
classification system.
"""

import os
from datetime import datetime
from pathlib import Path

import torch
from sentence_transformers import SentenceTransformer

from classifier.config import load_env
from classifier.head import ClassifierHead

# Load environment variables (including HF_TOKEN)
load_env()

# Model and training configuration.
# NOTE(review): MODEL_NAME was previously assigned twice with the same value;
# the duplicate assignment has been removed.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
CLASSIFIER_NAME = "davidgray/health-query-triage"
CATEGORIES: list[str] = ["medical", "insurance"]
CHECKPOINT_PATH = "classifier/checkpoints"
DATETIME_FORMAT = "%Y%m%d_%H%M%S"

# Device configuration - use David's newer approach with fallback
try:
    # torch.accelerator is a recent PyTorch API (presumably >= 2.4 — verify
    # against the project's pinned torch version); it resolves CUDA/MPS/etc.
    DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
except AttributeError:
    # Fallback for older PyTorch versions without torch.accelerator
    if torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    elif torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

print(f"Using {DEVICE} device")

def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]:
    """Load the embeddinggemma-300m-medical model and a classification head.

    Args:
        model_id: Optional Hugging Face model id for a pretrained classifier
            head. When ``None``, a freshly initialized head is created.
        num_labels: Number of output labels for a newly initialized head
            (ignored when ``model_id`` is given). Defaults to
            ``len(CATEGORIES)``.

    Returns:
        tuple: ``(embedding_model, classifier_head)``, both moved to DEVICE.

    Raises:
        RuntimeError: If the embedding model or classifier head cannot be
            loaded (e.g. no network access or missing dependencies).
    """
    try:
        # Task-specific prompts for EmbeddingGemma; 'classification' is the
        # default since this module triages queries into CATEGORIES.
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )

        if model_id:
            model_head = ClassifierHead.from_pretrained(model_id)
        else:
            model_head = ClassifierHead(num_labels)

    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        print("Please ensure you have an internet connection and the transformers library installed.")
        # Chain the original exception so the root cause is not lost.
        raise RuntimeError("Failed to load the embedding model.") from e

    return model_body.to(DEVICE), model_head.to(DEVICE)

def get_latest_checkpoint(checkpoint_path: str) -> str:
    """Return the full path of the most recent checkpoint in *checkpoint_path*.

    Checkpoint entries are named with a sortable timestamp (see
    DATETIME_FORMAT), so the lexicographically greatest name is the newest.

    Args:
        checkpoint_path: Directory containing checkpoint entries.

    Returns:
        str: Path to the most recent checkpoint entry.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no
            entries (previously an empty directory crashed with an opaque
            IndexError).
    """
    entries = sorted(os.listdir(checkpoint_path))
    if not entries:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_path!r}")
    return os.path.join(checkpoint_path, entries[-1])