Spaces:
Sleeping
Sleeping
"""
LOINC2SDTM SAPBERT Classification Space
Uses trained SAPBERT model for multi-label classification
"""
import json

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModel
print("Loading SAPBERT model...")

# Configuration: Hub repository that hosts the tokenizer, the trained weights
# and the vocabularies file used below.
MODEL_PATH = "panikos/loinc2sdtm-sapbert-extended-model"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Download and load vocabularies from Hub
print("Downloading vocabularies.json from Hub...")
vocab_file_path = hf_hub_download(
    repo_id=MODEL_PATH,
    filename="vocabularies.json",
    repo_type="model",
)
# Read the JSON explicitly as UTF-8 so the result does not depend on the
# platform's default text encoding.
with open(vocab_file_path, "r", encoding="utf-8") as f:
    vocab_data = json.load(f)

# vocabularies: per-field vocabulary; its length gives the number of classes
#               for that field's classifier head.
# id2label:     per-field mapping from class index (as a string key, because
#               JSON object keys are strings) to the label text.
# train_fields: the fields for which predictions are produced.
vocabularies = vocab_data['vocabularies']
id2label = vocab_data['id2label']
train_fields = vocab_data['train_fields']
print(f"Loaded vocabularies for {len(train_fields)} fields")
# Define model architecture (must match training)
class LOINC2SDTMClassifier(nn.Module):
    """Shared encoder with one independent classification head per field.

    The encoder output at sequence position 0 is fed to every head; each head
    is a two-layer MLP (Linear -> ReLU -> Dropout -> Linear) that produces the
    logits for one field.

    Args:
        base_model: Encoder module. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` and ``attention_mask`` keyword
            arguments, return an object with a ``last_hidden_state`` tensor.
        num_classes_dict: Mapping of field name to the number of output
            classes for that field.
        dropout: Dropout probability used inside every head. Defaults to 0.1,
            the value that was previously hard-coded.
    """

    def __init__(self, base_model, num_classes_dict, dropout: float = 0.1):
        super().__init__()
        self.encoder = base_model
        self.config = base_model.config
        self.hidden_size = base_model.config.hidden_size
        # The attribute names ("encoder", "classifiers") and the order of the
        # layers inside each nn.Sequential determine the state_dict keys
        # (e.g. "classifiers.<field>.0.weight"), so they must stay as they are
        # for previously saved weights to load.
        self.classifiers = nn.ModuleDict({
            field: nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_size // 2, num_classes),
            )
            for field, num_classes in num_classes_dict.items()
        })

    def forward(self, input_ids, attention_mask):
        """Return a dict mapping each field name to that head's logits tensor."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at position 0 of every sequence as the pooled
        # representation shared by all heads.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return {
            field: classifier(cls_embedding)
            for field, classifier in self.classifiers.items()
        }
# Load base model and create classifier
print("Loading base SAPBERT model...")
base_model = AutoModel.from_pretrained(MODEL_PATH)
# One output class per vocabulary entry, for every field.
num_classes_dict = {field: len(vocab) for field, vocab in vocabularies.items()}
model = LOINC2SDTMClassifier(base_model, num_classes_dict)

# Load model weights from safetensors
# (load_file is imported with the other dependencies at the top of the file.)
print("Loading trained model weights from safetensors...")
model_weights_path = hf_hub_download(
    repo_id=MODEL_PATH,
    filename="model.safetensors",
    repo_type="model",
)
state_dict = load_file(model_weights_path)
model.load_state_dict(state_dict)
# Switch to evaluation mode so the dropout layers in the heads are inactive.
model.eval()
print("Model loaded successfully!")
def predict_loinc_mapping(loinc_code, component="", property_val="", system=""):
    """Predict SDTM mappings for a LOINC code and return them as a JSON string.

    Args:
        loinc_code: LOINC code to map (required).
        component: Optional LOINC component text.
        property_val: Optional LOINC property text.
        system: Optional LOINC system text.

    Returns:
        A JSON document (indented by 2) as a string. On success it holds
        ``loinc_code``, ``input_context`` and ``sdtm_mappings`` (one predicted
        label per entry of ``train_fields``). If the code is empty, or if
        anything fails during prediction, it holds ``error`` and
        ``loinc_code`` instead; the function does not raise.
    """
    # Treat None like an empty string and trim surrounding whitespace, so a
    # missing value never reaches the model as the literal text "None".
    loinc_code, component, property_val, system = (
        "" if value is None else str(value).strip()
        for value in (loinc_code, component, property_val, system)
    )

    # The code is the only required input; without it a prediction would be
    # made on text that carries no LOINC identifier at all.
    if not loinc_code:
        return json.dumps({
            "error": "LOINC code is required",
            "loinc_code": loinc_code
        }, indent=2)

    try:
        # Create input text (same format as training, per the original note)
        input_text = f"{loinc_code} {component} {property_val} {system}".strip()

        # Tokenize
        encoding = tokenizer(
            input_text,
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors='pt'
        )

        # Get predictions (no gradients are needed for inference)
        with torch.no_grad():
            logits = model(encoding['input_ids'], encoding['attention_mask'])

        # Convert logits to labels: take the highest-scoring class index per
        # field. id2label was loaded from JSON, so its keys are strings and
        # the integer index has to be converted to str for the lookup.
        predictions = {}
        for field in train_fields:
            pred_idx = torch.argmax(logits[field], dim=1).item()
            predictions[field] = id2label[field][str(pred_idx)]

        # Format as JSON
        result = {
            "loinc_code": loinc_code,
            "input_context": {
                "component": component or "not provided",
                "property": property_val or "not provided",
                "system": system or "not provided"
            },
            "sdtm_mappings": predictions
        }
        return json.dumps(result, indent=2)
    except Exception as e:
        # UI boundary: report the failure in the output panel instead of
        # letting the exception propagate to the caller.
        return json.dumps({
            "error": str(e),
            "loinc_code": loinc_code
        }, indent=2)
# Build the Gradio UI: inputs on the left, JSON result on the right.
# NOTE(review): the nesting of the layout blocks and the indentation/blank
# lines inside the Markdown strings were reconstructed from a flattened copy
# of this file - confirm against the deployed app.
with gr.Blocks(title="LOINC to SDTM Mapper (SAPBERT)") as demo:
    gr.Markdown("""
    # LOINC to SDTM LB Mapper
    Enter a LOINC code to get predicted SDTM LB domain mappings.
    Optionally provide additional LOINC metadata for better predictions.
    **Model:** SAPBERT-based multi-label classifier
    **Training data:** 2,304 FDA-approved LOINC-to-SDTM mappings
    **Fields predicted:** LBTESTCD, LBTEST, LBSPEC, LBSTRESU, LBMETHOD, LBPTFL, LBRESTYP, LBRESSCL
    """)

    with gr.Row():
        # Left column: the four text inputs and the submit button.
        with gr.Column():
            loinc_input = gr.Textbox(
                label="LOINC Code (required)",
                placeholder="e.g., 1558-6",
                value="1558-6",
            )
            component_input = gr.Textbox(
                label="Component (optional)",
                placeholder="e.g., Glucose",
            )
            property_input = gr.Textbox(
                label="Property (optional)",
                placeholder="e.g., Mass concentration",
            )
            system_input = gr.Textbox(
                label="System (optional)",
                placeholder="e.g., Serum",
            )
            submit_btn = gr.Button("Get SDTM Mapping", variant="primary")

        # Right column: read-only JSON view of the prediction.
        with gr.Column():
            output = gr.Code(
                label="SDTM Mapping (JSON)",
                language="json",
            )

    # The same four components feed both the button handler and the examples,
    # in the positional order predict_loinc_mapping expects.
    mapping_inputs = [loinc_input, component_input, property_input, system_input]

    submit_btn.click(
        fn=predict_loinc_mapping,
        inputs=mapping_inputs,
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["1558-6", "Glucose", "Mass concentration", "Serum"],
            ["2345-7", "Glucose", "MCnc", "Serum or Plasma"],
            ["2160-0", "Creatinine", "MCnc", "Serum or Plasma"],
            ["883-9", "ABO group", "Type", "Blood"],
            ["718-7", "Hemoglobin", "MCnc", "Blood"],
        ],
        inputs=mapping_inputs,
        label="Example LOINC Codes",
    )

    gr.Markdown("""
    ---
    ### About This Model
    This classifier was trained on FDA-approved LOINC-to-SDTM mappings using SAPBERT
    (a biomedical entity linking model based on PubMedBERT).
    **How it works:**
    1. Combines LOINC code with metadata (component, property, system)
    2. Encodes using SAPBERT (trained on PubMed biomedical text)
    3. Predicts 8 SDTM fields simultaneously using separate classifier heads
    **Expected accuracy:** 95-98% based on training metrics
    """)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()