"""LOINC2SDTM SAPBERT Classification Space.

Uses a trained SAPBERT model for multi-label classification of LOINC codes
into SDTM LB-domain fields, served behind a Gradio UI.
"""

import json

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

print("Loading SAPBERT model...")

# Configuration: Hugging Face Hub repo holding tokenizer, weights, and vocabularies.
MODEL_PATH = "panikos/loinc2sdtm-sapbert-extended-model"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Download and load vocabularies from Hub
print("Downloading vocabularies.json from Hub...")
vocab_file_path = hf_hub_download(
    repo_id=MODEL_PATH,
    filename="vocabularies.json",
    repo_type="model",
)
with open(vocab_file_path, "r") as f:
    vocab_data = json.load(f)

vocabularies = vocab_data['vocabularies']    # field -> list/dict of class labels
id2label = vocab_data['id2label']            # field -> {str(class index): label} (JSON keys are strings)
train_fields = vocab_data['train_fields']    # SDTM fields the model was trained to predict

print(f"Loaded vocabularies for {len(train_fields)} fields")


class LOINC2SDTMClassifier(nn.Module):
    """SAPBERT encoder with one classification head per SDTM field.

    The architecture must match the one used at training time exactly so
    the saved ``state_dict`` loads without key/shape mismatches.
    """

    def __init__(self, base_model, num_classes_dict):
        """
        Args:
            base_model: A pretrained ``transformers`` encoder (SAPBERT).
            num_classes_dict: Mapping of SDTM field name -> number of classes.
        """
        super().__init__()
        self.encoder = base_model
        self.config = base_model.config
        self.hidden_size = base_model.config.hidden_size
        # One small MLP head per field; ModuleDict keys become state_dict keys.
        self.classifiers = nn.ModuleDict({
            field: nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(self.hidden_size // 2, num_classes),
            )
            for field, num_classes in num_classes_dict.items()
        })

    def forward(self, input_ids, attention_mask):
        """Return a dict of per-field logits computed from the [CLS] embedding."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token representation
        logits = {
            field: classifier(cls_embedding)
            for field, classifier in self.classifiers.items()
        }
        return logits


# Load base model and create classifier
print("Loading base SAPBERT model...")
base_model = AutoModel.from_pretrained(MODEL_PATH)
num_classes_dict = {field: len(vocab) for field, vocab in vocabularies.items()}
model = LOINC2SDTMClassifier(base_model, num_classes_dict)

# Load the trained weights (encoder + classifier heads) from safetensors.
print("Loading trained model weights from safetensors...")
model_weights_path = hf_hub_download(
    repo_id=MODEL_PATH,
    filename="model.safetensors",
    repo_type="model",
)
state_dict = load_file(model_weights_path)
model.load_state_dict(state_dict)
model.eval()

print("Model loaded successfully!")


def predict_loinc_mapping(loinc_code, component="", property_val="", system=""):
    """Predict SDTM LB-domain mappings for a LOINC code.

    Args:
        loinc_code: LOINC code, e.g. "1558-6" (required).
        component: Optional LOINC component, e.g. "Glucose".
        property_val: Optional LOINC property, e.g. "MCnc".
        system: Optional LOINC system, e.g. "Serum".

    Returns:
        A pretty-printed JSON string with the input context and the
        predicted SDTM fields, or a JSON error object on failure.
    """
    try:
        # Guard: an empty code gives the model nothing meaningful to encode.
        if not loinc_code or not loinc_code.strip():
            return json.dumps({
                "error": "LOINC code is required",
                "loinc_code": loinc_code,
            }, indent=2)

        # Create input text (same format as training)
        input_text = f"{loinc_code} {component} {property_val} {system}".strip()

        # Tokenize with the same padding/truncation settings used at training time.
        encoding = tokenizer(
            input_text,
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors='pt',
        )

        # Get predictions
        with torch.no_grad():
            logits = model(encoding['input_ids'], encoding['attention_mask'])

        # Convert logits to predictions (argmax per field).
        predictions = {}
        for field in train_fields:
            pred_idx = torch.argmax(logits[field], dim=1).item()
            # id2label comes from JSON, so its inner keys are strings, not ints.
            pred_label = id2label[field][str(pred_idx)]
            predictions[field] = pred_label

        # Format as JSON
        result = {
            "loinc_code": loinc_code,
            "input_context": {
                "component": component if component else "not provided",
                "property": property_val if property_val else "not provided",
                "system": system if system else "not provided",
            },
            "sdtm_mappings": predictions,
        }
        return json.dumps(result, indent=2)

    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the app.
        return json.dumps({
            "error": str(e),
            "loinc_code": loinc_code,
        }, indent=2)


# Create Gradio interface
with gr.Blocks(title="LOINC to SDTM Mapper (SAPBERT)") as demo:
    gr.Markdown("""
    # LOINC to SDTM LB Mapper

    Enter a LOINC code to get predicted SDTM LB domain mappings.
    Optionally provide additional LOINC metadata for better predictions.

    **Model:** SAPBERT-based multi-label classifier
    **Training data:** 2,304 FDA-approved LOINC-to-SDTM mappings
    **Fields predicted:** LBTESTCD, LBTEST, LBSPEC, LBSTRESU, LBMETHOD, LBPTFL, LBRESTYP, LBRESSCL
    """)

    with gr.Row():
        with gr.Column():
            loinc_input = gr.Textbox(
                label="LOINC Code (required)",
                placeholder="e.g., 1558-6",
                value="1558-6",
            )
            component_input = gr.Textbox(
                label="Component (optional)",
                placeholder="e.g., Glucose",
            )
            property_input = gr.Textbox(
                label="Property (optional)",
                placeholder="e.g., Mass concentration",
            )
            system_input = gr.Textbox(
                label="System (optional)",
                placeholder="e.g., Serum",
            )
            submit_btn = gr.Button("Get SDTM Mapping", variant="primary")

        with gr.Column():
            output = gr.Code(
                label="SDTM Mapping (JSON)",
                language="json",
            )

    submit_btn.click(
        fn=predict_loinc_mapping,
        inputs=[loinc_input, component_input, property_input, system_input],
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["1558-6", "Glucose", "Mass concentration", "Serum"],
            ["2345-7", "Glucose", "MCnc", "Serum or Plasma"],
            ["2160-0", "Creatinine", "MCnc", "Serum or Plasma"],
            ["883-9", "ABO group", "Type", "Blood"],
            ["718-7", "Hemoglobin", "MCnc", "Blood"],
        ],
        inputs=[loinc_input, component_input, property_input, system_input],
        label="Example LOINC Codes",
    )

    gr.Markdown("""
    ---
    ### About This Model

    This classifier was trained on FDA-approved LOINC-to-SDTM mappings using
    SAPBERT (a biomedical entity linking model based on PubMedBERT).

    **How it works:**
    1. Combines LOINC code with metadata (component, property, system)
    2. Encodes using SAPBERT (trained on PubMed biomedical text)
    3. Predicts 8 SDTM fields simultaneously using separate classifier heads

    **Expected accuracy:** 95-98% based on training metrics
    """)

if __name__ == "__main__":
    demo.launch()