Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import json | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from sklearn.preprocessing import MultiLabelBinarizer | |
| import numpy as np | |
| import re | |
| # ---------------------------------------------------------------------- | |
| # Text Preprocessing (same as during training) | |
| # ---------------------------------------------------------------------- | |
def preprocess_text(text: str) -> str:
    """Normalise raw clinical text before tokenisation.

    The text is lower-cased and then passed through a fixed sequence of
    regex substitutions, applied in order:

    1. ``[** ... **]`` spans are replaced by a single space (presumably
       de-identification placeholders -- confirm against the training data).
    2. Runs of one repeated punctuation mark out of ``! ? . ,`` collapse to
       a single occurrence.
    3. Runs of carriage returns, newlines and tabs become one space.
    4. Any remaining run of whitespace becomes one space.

    Leading and trailing whitespace is stripped from the result.
    """
    # Order matters: placeholders are removed first so the whitespace they
    # leave behind is collapsed by the later rules.
    substitutions = (
        (r"\[\*\*.*?\*\*\]", " "),
        (r"([!?.,])\1+", r"\1"),
        (r"[\r\n\t]+", " "),
        (r"\s+", " "),
    )
    cleaned = text.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
| # ---------------------------------------------------------------------- | |
| # Load Trained Model and Artifacts | |
| # ---------------------------------------------------------------------- | |
def load_trained_model(model_dir: str):
    """Load the classifier, its tokenizer and the label binarizer.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for both the model
            and the tokenizer; it must also contain ``mlb_classes.json``.

    Returns:
        A ``(model, tokenizer, mlb)`` tuple. The model has been switched to
        evaluation mode, and ``mlb`` is a ``MultiLabelBinarizer`` whose
        classes are the list read from ``mlb_classes.json``.
    """
    # ``eval()`` returns the module itself, so the call can be chained.
    model = AutoModelForSequenceClassification.from_pretrained(model_dir).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    classes_path = f"{model_dir}/mlb_classes.json"
    with open(classes_path, "r") as classes_file:
        top_codes_list = json.load(classes_file)

    # Fitting on a single empty sample registers the predefined classes
    # without adding any labels of its own.
    mlb = MultiLabelBinarizer(classes=top_codes_list).fit([[]])
    return model, tokenizer, mlb
| # ---------------------------------------------------------------------- | |
| # Predict ICD-9 Codes | |
| # ---------------------------------------------------------------------- | |
def predict_icd9(input_text: str, model, tokenizer, mlb, max_length=512, threshold=0.5):
    """Predict the ICD-9 codes for a single piece of clinical text.

    The text is normalised with ``preprocess_text``, tokenised as one
    sequence (truncated and padded to ``max_length``) and run through
    ``model`` with gradient tracking disabled. Each logit is passed through a
    sigmoid; every label whose probability is strictly greater than
    ``threshold`` is reported.

    Args:
        input_text: Raw clinical text.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer callable compatible with ``model``.
        mlb: Fitted label binarizer used to map the 0/1 indicator row back to
            label names via ``inverse_transform``.
        max_length: Token length the input is truncated/padded to.
        threshold: Probability cut-off; a label is kept when ``prob > threshold``.

    Returns:
        The first (and only) row produced by ``mlb.inverse_transform`` -- the
        labels predicted for ``input_text``. Empty when no probability exceeds
        ``threshold``.
    """
    processed_text = preprocess_text(input_text)
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Drop only the batch axis. A bare ``squeeze()`` would also collapse the
    # label axis when the model has a single output label, leaving a 0-d
    # array that cannot form the (1, n_labels) indicator row that
    # ``inverse_transform`` requires.
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    y_pred = (probs > threshold).astype(int)
    predicted_codes = mlb.inverse_transform(np.array([y_pred]))
    return predicted_codes[0]
| # ---------------------------------------------------------------------- | |
| # Streamlit App | |
| # ---------------------------------------------------------------------- | |
# Page layout: title, sidebar threshold control, then the main text input.
st.title("ICD-9 Code Prediction")
model_dir = "./final_mode4l"

st.sidebar.header("Model Settings")
threshold = st.sidebar.slider(
    "Prediction Threshold",
    min_value=0.1,
    max_value=1.0,
    value=0.5,
    step=0.1,
)

st.write("Enter clinical text below to predict ICD-9 codes.")
input_text = st.text_area("Clinical Text", height=200)

# Nothing below runs until the user presses the button.
predict_clicked = st.button("Predict")

if predict_clicked and not input_text.strip():
    # Whitespace-only input is rejected before any model work is done.
    st.error("Please enter valid clinical text.")
elif predict_clicked:
    # The model and its artifacts are loaded on every click.
    st.write("Loading model...")
    model, tokenizer, mlb = load_trained_model(model_dir)

    st.write("Predicting...")
    predicted_codes = predict_icd9(
        input_text, model, tokenizer, mlb, threshold=threshold
    )

    if not predicted_codes:
        st.warning("No codes were predicted. Try lowering the threshold or using a different input.")
    else:
        st.success("Predicted ICD-9 Codes:")
        st.write(predicted_codes)