panikos's picture
Fix model loading: use model.safetensors instead of pytorch_model.bin
e80655d verified
"""
LOINC2SDTM SAPBERT Classification Space
Uses a fine-tuned SAPBERT model with one classification head per SDTM field (multi-output classification)
"""
import gradio as gr
import torch
import json
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
from huggingface_hub import hf_hub_download
print("Loading SAPBERT model...")

# Configuration
# Hugging Face Hub repo id. It hosts the tokenizer, the encoder
# (AutoModel.from_pretrained), vocabularies.json and the fine-tuned
# model.safetensors that are all loaded below.
MODEL_PATH = "panikos/loinc2sdtm-sapbert-extended-model"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Download and load vocabularies from Hub
print("Downloading vocabularies.json from Hub...")
vocab_file_path = hf_hub_download(
    repo_id=MODEL_PATH,
    filename="vocabularies.json",
    repo_type="model",
)
# JSON text is UTF-8; pass the encoding explicitly so the read does not
# depend on the host's locale default.
with open(vocab_file_path, "r", encoding="utf-8") as f:
    vocab_data = json.load(f)

# vocabularies: field -> label vocabulary; its length sizes that field's head.
# id2label: field -> {class index as *string* -> label}. JSON object keys are
#   always strings, which is why lookups use str(index).
# train_fields: the fields reported in every prediction.
vocabularies = vocab_data['vocabularies']
id2label = vocab_data['id2label']
train_fields = vocab_data['train_fields']
print(f"Loaded vocabularies for {len(train_fields)} fields")
# Define model architecture (must match training)
class LOINC2SDTMClassifier(nn.Module):
    """Shared transformer encoder feeding one classification head per field.

    Attribute names and layer order below determine the state-dict keys
    (``encoder.*`` and ``classifiers.<field>.0.*`` / ``classifiers.<field>.3.*``).
    They must therefore stay in sync with the checkpoint produced at training
    time.
    """

    def __init__(self, base_model, num_classes_dict):
        """
        Args:
            base_model: Encoder module. Its ``config.hidden_size`` sizes the
                heads, and calling it must return an object exposing
                ``last_hidden_state``.
            num_classes_dict: Mapping of field name -> number of classes.
                One head is created per entry, in iteration order.
        """
        super().__init__()
        self.encoder = base_model
        self.config = base_model.config
        width = self.config.hidden_size
        self.hidden_size = width

        heads = nn.ModuleDict()
        for field, n_classes in num_classes_dict.items():
            heads[field] = self._make_head(width, n_classes)
        self.classifiers = heads

    @staticmethod
    def _make_head(in_features, num_classes):
        """Build a Linear -> ReLU -> Dropout(0.1) -> Linear head with a half-width bottleneck."""
        bottleneck = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, bottleneck),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return a dict mapping each field name to the logits of its head.

        Every head receives the encoder's final hidden state at sequence
        position 0 (the [CLS] slot for BERT-style tokenizers).
        """
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        pooled = hidden_states[:, 0, :]

        logits = {}
        for field, head in self.classifiers.items():
            logits[field] = head(pooled)
        return logits
# Assemble the classifier: pretrained encoder plus one head per field, each
# head sized by the length of that field's label vocabulary.
print("Loading base SAPBERT model...")
base_model = AutoModel.from_pretrained(MODEL_PATH)
num_classes_dict = {name: len(labels) for name, labels in vocabularies.items()}
model = LOINC2SDTMClassifier(base_model, num_classes_dict)

# Replace the freshly built parameters with the fine-tuned checkpoint.
print("Loading trained model weights from safetensors...")
from safetensors.torch import load_file

model_weights_path = hf_hub_download(
    repo_type="model",
    repo_id=MODEL_PATH,
    filename="model.safetensors",
)
state_dict = load_file(model_weights_path)
# Default strict loading: every key must match the classifier layout exactly.
model.load_state_dict(state_dict)
# Evaluation mode disables the Dropout layers for deterministic inference.
model.eval()
print("Model loaded successfully!")
def _clean_field(value):
    """Return *value* as a whitespace-stripped string, mapping None to ""."""
    return "" if value is None else str(value).strip()


def predict_loinc_mapping(loinc_code, component="", property_val="", system=""):
    """
    Predict SDTM mappings for a LOINC code.

    Args:
        loinc_code: LOINC code to map (required; blank input is rejected).
        component: Optional LOINC component, e.g. "Glucose".
        property_val: Optional LOINC property, e.g. "MCnc".
        system: Optional LOINC system, e.g. "Serum or Plasma".

    Returns:
        A JSON string (indent=2). On success it contains ``loinc_code``,
        ``input_context`` and ``sdtm_mappings``, which maps every field in
        ``train_fields`` to its predicted label. On failure, including a blank
        LOINC code, it contains ``error`` and ``loinc_code``.
    """
    # Callers (e.g. Gradio API clients) may pass None. Normalize it so the
    # literal text "None" never reaches the model input, and trim stray
    # whitespace from each field.
    loinc_code = _clean_field(loinc_code)
    component = _clean_field(component)
    property_val = _clean_field(property_val)
    system = _clean_field(system)

    if not loinc_code:
        # The UI marks this field as required. Without this check the argmax
        # below would still return a label for every field.
        return json.dumps({
            "error": "LOINC code is required",
            "loinc_code": loinc_code
        }, indent=2)

    try:
        # Create input text (same format as training)
        input_text = f"{loinc_code} {component} {property_val} {system}".strip()

        # Tokenize to a fixed length of 64 tokens (longer input is truncated)
        encoding = tokenizer(
            input_text,
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors='pt'
        )

        # Get predictions (inference only, no autograd graph)
        with torch.no_grad():
            logits = model(encoding['input_ids'], encoding['attention_mask'])

        # Take the highest-scoring class per field. The int index is converted
        # to str because id2label came from JSON, whose object keys are strings.
        predictions = {}
        for field in train_fields:
            pred_idx = torch.argmax(logits[field], dim=1).item()
            predictions[field] = id2label[field][str(pred_idx)]

        # Format as JSON
        result = {
            "loinc_code": loinc_code,
            "input_context": {
                "component": component if component else "not provided",
                "property": property_val if property_val else "not provided",
                "system": system if system else "not provided"
            },
            "sdtm_mappings": predictions
        }
        return json.dumps(result, indent=2)
    except Exception as e:
        # UI boundary: report the failure to the user as JSON, and also print
        # the traceback so the error is visible in the server logs.
        import traceback
        traceback.print_exc()
        return json.dumps({
            "error": str(e),
            "loinc_code": loinc_code
        }, indent=2)
# Create Gradio interface
# The field list and count shown to users are derived from train_fields, the
# same list predict_loinc_mapping reports, so the description cannot drift
# from what the model actually predicts. The Markdown lives in column-0
# constants so its rendering does not depend on source indentation.
_FIELDS_TEXT = ", ".join(train_fields)

_HEADER_MD = f"""
# LOINC to SDTM LB Mapper

Enter a LOINC code to get predicted SDTM LB domain mappings.
Optionally provide additional LOINC metadata for better predictions.

**Model:** SAPBERT-based multi-output classifier (one head per SDTM field)

**Training data:** 2,304 FDA-approved LOINC-to-SDTM mappings

**Fields predicted:** {_FIELDS_TEXT}
"""

_ABOUT_MD = f"""
---
### About This Model

This classifier was trained on FDA-approved LOINC-to-SDTM mappings using SAPBERT
(a biomedical entity linking model based on PubMedBERT).

**How it works:**

1. Combines LOINC code with metadata (component, property, system)
2. Encodes using SAPBERT (trained on PubMed biomedical text)
3. Predicts {len(train_fields)} SDTM fields simultaneously using separate classifier heads

**Expected accuracy:** 95-98% based on training metrics
"""

with gr.Blocks(title="LOINC to SDTM Mapper (SAPBERT)") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        with gr.Column():
            loinc_input = gr.Textbox(
                label="LOINC Code (required)",
                placeholder="e.g., 1558-6",
                value="1558-6"
            )
            component_input = gr.Textbox(
                label="Component (optional)",
                placeholder="e.g., Glucose"
            )
            property_input = gr.Textbox(
                label="Property (optional)",
                placeholder="e.g., Mass concentration"
            )
            system_input = gr.Textbox(
                label="System (optional)",
                placeholder="e.g., Serum"
            )
            submit_btn = gr.Button("Get SDTM Mapping", variant="primary")
        with gr.Column():
            output = gr.Code(
                label="SDTM Mapping (JSON)",
                language="json"
            )

    # Same input order as predict_loinc_mapping's positional parameters.
    _predict_inputs = [loinc_input, component_input, property_input, system_input]

    submit_btn.click(
        fn=predict_loinc_mapping,
        inputs=_predict_inputs,
        outputs=output
    )

    gr.Examples(
        examples=[
            ["1558-6", "Glucose", "Mass concentration", "Serum"],
            ["2345-7", "Glucose", "MCnc", "Serum or Plasma"],
            ["2160-0", "Creatinine", "MCnc", "Serum or Plasma"],
            ["883-9", "ABO group", "Type", "Blood"],
            ["718-7", "Hemoglobin", "MCnc", "Blood"],
        ],
        inputs=_predict_inputs,
        label="Example LOINC Codes"
    )

    gr.Markdown(_ABOUT_MD)

if __name__ == "__main__":
    demo.launch()