File size: 4,438 Bytes
1718df6 09d0e11 079da61 09d0e11 079da61 09d0e11 fa63877 09d0e11 079da61 09d0e11 fa63877 09d0e11 fa63877 3d89652 09d0e11 fa63877 09d0e11 fa63877 09d0e11 12b9611 fa63877 c22378f fa63877 3d89652 079da61 fa63877 09d0e11 fa63877 3d89652 079da61 3d89652 fa63877 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pickle
# Globals shared with the prediction function; populated on successful load.
model, tokenizer, label_encoder, class_names = None, None, None, []

# Select the device up front so it is ALWAYS defined. The original set it
# inside the try block, so an early failure (e.g. the sklearn check) could
# leave the module-level `device` referenced by the prediction function
# unbound.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model, tokenizer, and label encoder.
try:
    # Fail early with an actionable message when scikit-learn (required to
    # unpickle the label encoder) is missing.
    try:
        import sklearn
        print("scikit-learn is available")
    except ImportError:
        raise ImportError("scikit-learn is not installed. Please install it using: pip install scikit-learn")
    # Directory that holds the fine-tuned model artifacts.
    model_path = "./medical_classifier_model"
    # Load the classifier and move it to the selected device.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    # Load the tokenizer that matches the fine-tuned model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Load the label encoder (maps class indices back to specialty names).
    # NOTE(review): pickle.load is only safe on trusted, locally-produced
    # files — this path is local, so acceptable here.
    with open(f'{model_path}/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)
    # Cache the human-readable class names for the UI.
    class_names = list(label_encoder.classes_)
    print("Model, tokenizer, and label encoder loaded successfully!")
except Exception as e:
    print(f"Error loading model components: {e}")
    print("The application will run in fallback mode with limited functionality.")
    # Fallback values when loading fails; the UI degrades gracefully.
    model, tokenizer, label_encoder, class_names = None, None, None, []
def predict_medical_specialty(text):
    """
    Predict the medical specialty of a given text using the fine-tuned model.

    Args:
        text: Raw medical/clinical text to classify. May be None or empty.

    Returns:
        dict mapping up to the top-3 predicted specialty names to their
        softmax probabilities, or a single sentinel entry
        ({"No input": 1.0} / {"Model Error": 1.0}) when the input is
        empty/whitespace or the model components failed to load.
    """
    # Guard: empty, None, or whitespace-only input gets a sentinel response.
    # (The original let "   " through and misreported it as a model error.)
    if not text or not text.strip():
        return {"No input": 1.0}
    # Guard: degrade gracefully when any component failed to load at startup.
    if not all([model, tokenizer, label_encoder]):
        return {"Model Error": 1.0}
    # Ensure the model is in evaluation mode (disables dropout etc.).
    model.eval()
    # Tokenize the input and move the tensors to the model's device.
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        # Forward pass; logits have shape (1, num_classes).
        outputs = model(**inputs)
        # Softmax over the class dimension to get probabilities.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # k may be fewer than 3 when the model knows fewer classes.
        top_k = min(3, len(class_names))
        scores, indices = torch.topk(probabilities, k=top_k)
        # Index the batch dimension explicitly: (1, k) -> (k,).
        # Using [0] instead of .squeeze() keeps a 1-d result even when
        # top_k == 1, where .squeeze() would collapse to a 0-d scalar and
        # break inverse_transform/zip below.
        predicted_labels = label_encoder.inverse_transform(indices[0].cpu().numpy())
        # Map each predicted specialty name to its probability.
        return {label: score.item() for label, score in zip(predicted_labels, scores[0])}
# Example medical texts offered as one-click demo inputs in the UI.
examples = [
    "Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
    "Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
    "Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
    "Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder."
]

# Show as many top classes as the model exposes; default to 3 when the
# class list is unavailable (fallback mode).
num_classes = len(class_names) if class_names else 3

# Surface the examples and the happy-path description only when every
# component loaded; otherwise advertise fallback mode instead.
if all([model, tokenizer, label_encoder]):
    include_examples = examples
    description_text = "This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
else:
    include_examples = None
    description_text = "⚠️ Model loading failed (likely missing scikit-learn dependency). Please ensure all requirements are installed. The app is running in fallback mode."

# Assemble the Gradio interface around the prediction function.
iface = gr.Interface(
    fn=predict_medical_specialty,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste a medical document or text here...",
        label="Medical Text"
    ),
    outputs=gr.Label(num_top_classes=num_classes),
    title="Medical Case Classifier",
    description=description_text,
    examples=include_examples
)

# Launch the web app only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()