Model Card: LSTM-BERT for Longitudinal Sequence Classification
Model Overview
Model Name: LSTM-BERT (Longitudinal Clinical Note Classifier)
Developer: NLP for Biomedical Information Analysis (NLP4BIA-BSC)
Model Type: Hierarchical Transformer-RNN Hybrid
Repository/API: Predictor API (GitHub)
Model Description
The LSTM-BERT is a hierarchical architecture designed to classify longitudinal sequences of text documents (specifically clinical "visits" or cases) accompanied by temporal data. Unlike standard BERT models that process a single block of text, this model treats a patient's history as a sequence of distinct events occurring over time.
Technical Architecture
The model processes data in two distinct stages:
Visit Encoding (RoBERTa):
- Each individual visit text sequence is encoded using a pre-trained RoBERTa model.
- Input: Token IDs of shape $(V, S)$ where $V$ is the number of visits and $S$ is the sequence length.
- Pooling: A weighted CLS pooling strategy is applied to extract a dense vector representation for each visit.
Longitudinal Aggregation (Bi-LSTM + Attention):
- Temporal Injection: The visit embeddings are concatenated with a projection of the visit dates (processed as log-deltas).
- Sequence Modeling: A Bidirectional LSTM processes the sequence of visit embeddings to capture temporal dependencies and progression.
- Attention Pooling: A learnable attention mechanism calculates a weight for each visit, allowing the model to focus on the most clinically relevant events in the history.
- Classification: The weighted sum of the LSTM outputs is passed through a linear classifier.
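To make the aggregation stage concrete, here is a minimal PyTorch sketch of the Bi-LSTM + attention pooling described above. All module names (`LongitudinalHead`, `time_proj`, `attn`, `classifier`) and the exact wiring are illustrative assumptions, not the repository's actual identifiers:

```python
import torch
import torch.nn as nn

class LongitudinalHead(nn.Module):
    """Sketch of stage 2: temporal injection, Bi-LSTM, attention pooling.
    Names and wiring are assumptions, not the repo's actual code."""

    def __init__(self, hidden_size=768, visit_time_proj=8,
                 lstm_hidden=256, attn_dim=64, output_dim=2):
        super().__init__()
        self.time_proj = nn.Linear(2, visit_time_proj)   # project (log_prev, log_start)
        self.lstm = nn.LSTM(hidden_size + visit_time_proj, lstm_hidden,
                            batch_first=True, bidirectional=True)
        self.attn = nn.Sequential(nn.Linear(2 * lstm_hidden, attn_dim),
                                  nn.Tanh(), nn.Linear(attn_dim, 1))
        self.classifier = nn.Linear(2 * lstm_hidden, output_dim)

    def forward(self, visit_embs, visit_times):
        # visit_embs: (V, hidden_size) pooled RoBERTa vectors; visit_times: (V, 2)
        x = torch.cat([visit_embs, self.time_proj(visit_times)], dim=-1)
        h, _ = self.lstm(x.unsqueeze(0))                 # (1, V, 2*lstm_hidden)
        w = torch.softmax(self.attn(h), dim=1)           # (1, V, 1) per-visit weights
        pooled = (w * h).sum(dim=1)                      # weighted sum over visits
        return self.classifier(pooled), w.squeeze(-1).squeeze(0)
```

Applying a softmax to the `output_dim=2` logits then yields the positive-class probability reported by the pipeline.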
Required Files
To use the model, your `local_model_path` directory must contain:
- `config.json`: an `LSTMBERTConfig` (RobertaConfig-derived) matching the saved weights
- `pytorch_model.bin`: model weights saved by `transformers` (compatible with `from_pretrained`)
- Tokenizer files (e.g., `vocab.json`, `merges.txt`, or `tokenizer.json`) compatible with `RobertaTokenizer`
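A quick pre-load check along these lines can catch a missing artifact early (a sketch, not part of the pipeline; the exact tokenizer file names depend on how the tokenizer was saved):

```python
import os

def check_model_dir(path):
    """Verify the minimum artifacts listed above are present."""
    missing = [f for f in ("config.json", "pytorch_model.bin")
               if not os.path.isfile(os.path.join(path, f))]
    has_tokenizer = any(os.path.isfile(os.path.join(path, f))
                        for f in ("tokenizer.json", "vocab.json"))
    problems = missing + ([] if has_tokenizer else ["tokenizer files"])
    if problems:
        raise FileNotFoundError(f"{path} is missing: {problems}")

check_model_dir("./saved_model")
```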
Installation
pip install torch transformers
Recommended Project Structure:
project_root/
├── app/
│   ├── config/
│   │   └── model_config.py  # Contains LSTMBERTConfig class
│   └── models/
│       ├── base_model.py    # Contains ModelClass
│       ├── utils.py         # Contains date_linear_impute, dates_to_log_deltas
│       └── modeling.py      # Contains LSTMBERT and PredictionPipeline
└── main.py                  # Your execution script
Quick Start
Basic Usage
from app.models.modeling import PredictionPipeline
# 1. Initialize the pipeline with the path to your trained model artifacts
model_path = "./saved_model"
pipeline = PredictionPipeline(local_model_path=model_path)
# 2. Prepare your input data
# 'case': A list of strings, where each string is the text of one visit/document
case_history = [
    "Patient presents with mild headache and nausea.",
    "Follow-up: Headache persists, recommended CT scan.",
    "Scan results negative. Symptoms resolved."
]
# 'dates': A list of strings corresponding to the visits
# Format: "DDMonYYYY" (e.g., "10Jan2024")
# Use empty string "" for unknown/missing dates
visit_dates = [
    "10Jan2024",
    "15Jan2024",
    "20Feb2024"
]
# 3. Run Inference
# Returns:
# - probability: Float (Probability of the positive class, e.g., class 1)
# - attention_weights: List[float] (Importance score for each visit)
probability, attention_weights = pipeline.predict(case_history, visit_dates)
print(f"Prediction Probability: {probability:.4f}")
print("Visit Importance:")
for date, weight in zip(visit_dates, attention_weights):
    print(f" - {date}: {weight:.4f}")
Input/Output Specifications
Inputs
- Text (`case`): A `List[str]` of length $V$ representing the number of visits. Each string is tokenized and truncated/padded to `cfg.max_length` (default: 128 tokens).
- Dates (`dates`): A `List[str]` matching the length of the text list. Format: `%d%b%Y` (e.g., `10Jan2024`).
  - Use `""` to indicate missing/unknown dates
  - The pipeline uses `date_linear_impute` and `dates_to_log_deltas` to convert dates into two floating-point features per visit (see the sketch below)
  - Internal representation: a tensor of shape $(V, 2)$ containing `(log_prev, log_start)`, the log-delta since the previous visit and the log-delta since the start of the history
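For intuition, the date conversion could look like the following simplified sketch. This is an assumption about the behavior of `dates_to_log_deltas`; the real helper in `app/models/utils.py` additionally imputes missing dates via `date_linear_impute`:

```python
import math
from datetime import datetime

def log_deltas_sketch(dates):
    """Simplified sketch (not the actual utils.py code): map '%d%b%Y'
    strings to (log_prev, log_start) features per visit."""
    parsed = [datetime.strptime(d, "%d%b%Y") for d in dates]
    start, prev, feats = parsed[0], parsed[0], []
    for d in parsed:
        feats.append((math.log1p((d - prev).days),    # days since previous visit
                      math.log1p((d - start).days)))  # days since first visit
        prev = d
    return feats

print(log_deltas_sketch(["10Jan2024", "15Jan2024", "20Feb2024"]))
# [(0.0, 0.0), (1.79..., 1.79...), (3.61..., 3.74...)]
```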
Outputs
- `syn_prob` (`float`): Softmax probability for the positive class (class index 1). If your model uses a different label mapping, adjust accordingly.
- `attn_list` (`List[float]`): Attention weights over visits (summing to approximately 1.0), representing how much the model focused on each specific visit in the sequence.
Model Configuration
The model relies on a custom `LSTMBERTConfig` class inheriting from `RobertaConfig`. Your `config.json` must include these parameters:
| Parameter | Description | Example Value |
|---|---|---|
| `hidden_size` | RoBERTa hidden dimension | 768 |
| `max_length` | Maximum sequence length per visit | 128 |
| `lstm_hidden` | Hidden size of the LSTM layer | 256 |
| `lstm_layers` | Number of stacked LSTM layers | 1 |
| `attn_dim` | Dimension of the internal attention projection | 64 |
| `output_dim` | Number of classification labels | 2 |
| `visit_time_proj` | Dimension to project the 2 time features into before the LSTM | 8 |
| `architectures` | Model class name | `["LSTMBERT"]` |
Example `config.json`:
{
  "hidden_size": 768,
  "max_length": 128,
  "lstm_hidden": 256,
  "lstm_layers": 1,
  "attn_dim": 64,
  "output_dim": 2,
  "visit_time_proj": 8,
  "architectures": ["LSTMBERT"]
}
Note: Use the exact `config.json` produced at training time to ensure architecture compatibility.
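For orientation, a minimal version of the config class could look like this sketch; the actual `app/config/model_config.py` may differ in details:

```python
from transformers import RobertaConfig

class LSTMBERTConfig(RobertaConfig):
    """Sketch of a RobertaConfig subclass carrying the extra
    LSTM/attention hyperparameters (assumed, not the actual file)."""

    def __init__(self, max_length=128, lstm_hidden=256, lstm_layers=1,
                 attn_dim=64, output_dim=2, visit_time_proj=8, **kwargs):
        super().__init__(**kwargs)
        self.max_length = max_length
        self.lstm_hidden = lstm_hidden
        self.lstm_layers = lstm_layers
        self.attn_dim = attn_dim
        self.output_dim = output_dim
        self.visit_time_proj = visit_time_proj

# cfg = LSTMBERTConfig.from_pretrained("./saved_model")
```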
Implementation Details
Device Handling
The pipeline automatically moves the model to CUDA if available; otherwise it runs on CPU. To force CPU inference (e.g., for deterministic results), hide the GPU before initializing the pipeline:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # must be set before CUDA is initialized
# Then initialize the pipeline
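Internally, the device choice presumably follows the standard PyTorch idiom:

```python
import torch

# Standard device-selection pattern (an assumption about the pipeline's internals)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Inference device: {device}")
```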
Internal Processing
- The pipeline loads the tokenizer and model with `local_files_only=True`, so all required files must be present locally
- The model's `forward` method expects `input_ids`, `attention_mask`, and a `visit_times` tensor shaped $(V, 2)$, as sketched below
- Attention weights are returned as a Python list of floats extracted from the `SequenceClassifierOutput`
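Put together, a single prediction plausibly reduces to the steps below. This is a sketch under the shape conventions above; `tokenizer`, `model`, and `visit_times` stand in for the pipeline's internals:

```python
import torch

def forward_sketch(tokenizer, model, case, visit_times, max_length=128):
    """Assumed internal flow: tokenize all visits as one batch, then
    run the custom forward with the (V, 2) time tensor."""
    enc = tokenizer(case, padding="max_length", truncation=True,
                    max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"],            # (V, S)
                    attention_mask=enc["attention_mask"],  # (V, S)
                    visit_times=visit_times)               # (V, 2)
    return out  # SequenceClassifierOutput with logits and attention weights
```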
Troubleshooting
Common Errors
`ValueError: visit_times shape must be (V, 2)`
- Cause: Mismatch between the `dates` length and the `case` length, or `format_dates` produced an incorrect shape
- Solution: Ensure `len(dates) == len(case)` and that all dates follow the `%d%b%Y` format; a pre-flight check is sketched below
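A check along these lines catches the mismatch before it reaches the model (hypothetical helper, not part of the pipeline):

```python
from datetime import datetime

def validate_inputs(case, dates):
    """Raise early on length mismatches or malformed date strings."""
    assert len(case) == len(dates), (
        f"Got {len(case)} texts but {len(dates)} dates")
    for d in dates:
        if d:  # empty strings mark missing dates and are allowed
            datetime.strptime(d, "%d%b%Y")  # raises ValueError if malformed
```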
Tokenizer errors
- Cause: Missing or incompatible tokenizer files
- Solution: Confirm all tokenizer files (`vocab.json`, `merges.txt`, or `tokenizer.json`) are present and compatible with `RobertaTokenizer`
Mismatched config/weights
- Cause: `config.json` doesn't match the saved model weights
- Solution: Ensure the config matches the training configuration (hidden sizes, number of labels, architecture parameters)
Model not found or loading errors
- Cause: Incorrect `local_model_path` or missing model files
- Solution: Verify the directory contains `config.json`, `pytorch_model.bin`, and the tokenizer files
Validation
Minimal Unit Test (Smoke Test)
def smoke_test(local_model_path):
    """Basic validation that the model loads and runs"""
    pipe = PredictionPipeline(local_model_path)
    case = ['Hello world']
    dates = ['01Jan2024']
    p, a = pipe.predict(case, dates)
    # Validate outputs
    assert 0.0 <= p <= 1.0, "Probability must be between 0 and 1"
    assert isinstance(a, list), "Attention must be a list"
    assert len(a) == len(case), "Attention length must match case length"
    assert abs(sum(a) - 1.0) < 0.01, "Attention weights should sum to ~1.0"
    print("✓ Smoke test passed")

# Run test
smoke_test('./saved_model')
Full Integration Test
def integration_test(local_model_path):
    """Test with a realistic clinical scenario"""
    pipe = PredictionPipeline(local_model_path)
    case = [
        "Pt reports cough and fever, started 2 days ago.",
        "Follow-up: symptoms improving after antitussive.",
        "Resolved, patient discharged."
    ]
    dates = ["10Jan2024", "12Jan2024", "15Jan2024"]
    prob, attn = pipe.predict(case=case, dates=dates)
    print(f"Positive probability: {prob:.4f}")
    print("Attention weights per visit:")
    for i, (date, weight) in enumerate(zip(dates, attn)):
        print(f"  Visit {i+1} ({date}): {weight:.4f}")

# Run test
integration_test('./saved_model')
API Reference
This model is served through a REST API developed in the predictor_api repository. For production deployment, API documentation, and advanced usage patterns, see:
Repository: https://github.com/nlp4bia-bsc/predictor_api
Contact & Support
Developer: NLP for Biomedical Information Analysis (NLP4BIA-BSC)
Issues: Please report issues through the GitHub repository
API Support: See the predictor_api repository for integration questions