| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.preprocessing import MultiLabelBinarizer | |
# Select the torch device used for both the model and its inputs:
# the default CUDA device when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the trained model and tokenizer
@st.cache_resource
def load_model(
    model_name="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
    weights_path="best_model_v2.pth",
    num_labels=8,
):
    """Load the fine-tuned multi-label classifier and its tokenizer.

    The architecture is instantiated from ``model_name`` with a
    ``num_labels``-way head configured for multi-label classification. Its
    weights are then replaced by the fine-tuned ``state_dict`` stored at
    ``weights_path``. Finally the model is moved to the module-level
    ``device`` and switched to eval mode.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every widget interaction. Without caching, the model
    would be rebuilt and its weights re-read from disk on every click.

    Args:
        model_name: Hugging Face model id used for both the model
            architecture and the tokenizer.
        weights_path: Path to the fine-tuned ``state_dict`` saved with
            ``torch.save``.
        num_labels: Size of the classification head. It must match the
            checkpoint and the number of classes returned by ``load_mlb``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )
    # map_location remaps tensors saved on a GPU onto the current device.
    # Without it, torch.load raises on a CPU-only host for a GPU-saved checkpoint.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)  # Move the model to the correct device
    model.eval()  # disable dropout for deterministic inference
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
def load_mlb():
    """Return a MultiLabelBinarizer over the model's eight output labels.

    The list order decides which logit column maps to which code, so it must
    match the label order used in training. The list is in lexicographic
    order, which is what a default-fitted MultiLabelBinarizer would produce.
    NOTE(review): confirm this against the training script.
    """
    label_order = [
        '81001.0', '99213.0', '99214.0',               # CPT codes
        'E11.9', 'I10', 'J45.909', 'M54.5', 'N39.0',   # ICD codes
    ]
    binarizer = MultiLabelBinarizer(classes=label_order)
    # Fitting on one sample containing every label only sets ``classes_``.
    # The explicit ``classes=`` argument already fixes the column order.
    binarizer.fit([label_order])
    return binarizer
| # # Load MultiLabelBinarizer | |
| # @st.cache_resource | |
| # def load_mlb(): | |
| # mlb = MultiLabelBinarizer() | |
| # # mlb.classes_ = np.load('mlb_classes.npy') # Assuming you saved the classes array during training | |
| # mlb = MultiLabelBinarizer(classes=['E11.9', 'I10', 'J45.909', 'M54.5', | |
| # 'N39.0', '81001.0', '99213.0', '99214.0']) # Update with actual labels | |
| # return mlb | |
model, tokenizer = load_model()
mlb = load_mlb()

# Labels shown as CPT codes; every other predicted label is shown as an ICD code.
CPT_CODES = {'81001.0', '99213.0', '99214.0'}  # Adjust based on your CPT code list
# Per-label probability cut-off applied to the sigmoid outputs (multi-label).
THRESHOLD = 0.5

# Streamlit UI
st.title("Automated Coding and Billing Prediction")

# Text input for Clinical Notes
clinical_note = st.text_area("Enter clinical notes to predict ICD and CPT codes")

# Prediction button
if st.button('Predict'):
    # Whitespace-only input is treated the same as an empty box.
    if not (clinical_note and clinical_note.strip()):
        st.write("Please enter clinical notes for prediction.")
    else:
        # Tokenize the input clinical note
        inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt')
        # Move inputs to the same device as the model
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Model inference
        with torch.no_grad():
            logits = model(**inputs).logits

        # Independent sigmoid per label, then threshold (multi-label classification)
        pred_labels = (torch.sigmoid(logits) > THRESHOLD).cpu().numpy()

        # inverse_transform returns one tuple of labels per input row. We sent a
        # single note, so take row 0. The outer list is never empty, so testing
        # it directly would always be truthy and "No codes predicted." would
        # never be shown.
        codes = mlb.inverse_transform(pred_labels)[0]

        # Format the results for better display
        if codes:
            st.write("**Predicted ICD and CPT Codes:**")
            for code in codes:
                if code in CPT_CODES:
                    st.write(f"- **CPT Code:** {code}")
                else:
                    st.write(f"- **ICD Code:** {code}")
        else:
            st.write("No codes predicted.")
| # # Prediction button | |
| # if st.button('Predict'): | |
| # if clinical_note: | |
| # # Tokenize the input clinical note | |
| # inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt') | |
| # # Move inputs to the GPU if available | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # inputs = {key: val.to(device) for key, val in inputs.items()} | |
| # # Model inference | |
| # with torch.no_grad(): | |
| # outputs = model(**inputs) | |
| # logits = outputs.logits | |
| # # Apply sigmoid and threshold the output (0.5 for multi-label classification) | |
| # pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy() | |
| # # Get the predicted ICD and CPT codes | |
| # predicted_codes = mlb.inverse_transform(pred_labels) | |
| # # Show the results | |
| # st.write("Predicted ICD and CPT Codes:") | |
| # st.write(predicted_codes) | |
| # else: | |
| # st.write("Please enter clinical notes for prediction.") | |