# adilsiraju
# Refactor prediction logic to return top 3 medical specialty predictions and remove unused zero-shot classification code
# 12b9611
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pickle
# Module-level handles for the classifier components. They stay None / empty
# when loading fails so the rest of the app can detect fallback mode.
model, tokenizer, label_encoder, class_names = None, None, None, []

# Directory the model, tokenizer and label_encoder.pkl are loaded from.
model_path = "./medical_classifier_model"

# Pick the device before the try block so `device` is always bound, even when
# an early loading step (such as the scikit-learn check) fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model, tokenizer, and label encoder
try:
    # Fail early with an actionable message when scikit-learn is missing
    # (presumably the pickled label encoder is a scikit-learn object).
    try:
        import sklearn  # noqa: F401 -- imported only to verify availability
        print("scikit-learn is available")
    except ImportError as err:
        raise ImportError(
            "scikit-learn is not installed. Please install it using: pip install scikit-learn"
        ) from err

    # Load the model and move it to the selected device
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)

    # Load the tokenizer from the same directory as the model
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file. Only load label_encoder.pkl from a trusted model directory.
    with open(f'{model_path}/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)

    # Class names in the order stored on the label encoder
    class_names = list(label_encoder.classes_)
    print("Model, tokenizer, and label encoder loaded successfully!")
except Exception as e:
    # Deliberate best-effort boundary: any loading failure is reported and the
    # app keeps running without a model instead of crashing at import time.
    print(f"Error loading model components: {e}")
    print("The application will run in fallback mode with limited functionality.")
    # Reset everything so a partially loaded state is never used
    model, tokenizer, label_encoder, class_names = None, None, None, []
def predict_medical_specialty(text):
    """
    Predict the most likely medical specialties for a piece of text.

    Runs the fine-tuned sequence-classification model on ``text`` and returns
    the highest-probability classes (at most three).

    Args:
        text: The medical text to classify.

    Returns:
        A dict mapping a label to a float score:

        * ``{"No input": 1.0}`` when ``text`` is empty, ``None`` or only
          whitespace;
        * ``{"Model Error": 1.0}`` when the model, tokenizer or label encoder
          is not loaded;
        * otherwise up to three ``{specialty: probability}`` entries.
    """
    # Treat whitespace-only text like empty input instead of classifying it
    if not text or (isinstance(text, str) and not text.strip()):
        return {"No input": 1.0}

    # Identity checks: availability must not depend on the objects' truthiness
    if model is None or tokenizer is None or label_encoder is None:
        return {"Model Error": 1.0}

    # Ensure the model is in evaluation mode
    model.eval()

    # Tokenize the input text and move the tensors to the model's device
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class dimension turns logits into probabilities
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Never request more classes than either the encoder or the model provides
    top_k = min(3, len(class_names), probabilities.shape[-1])
    scores, indices = torch.topk(probabilities, k=top_k)

    # Select the single batch row explicitly (assumes one input text, i.e. a
    # leading batch dimension of 1). squeeze() would also drop the class
    # dimension when top_k == 1, producing 0-d tensors that can neither be
    # iterated nor passed to inverse_transform.
    top_scores = scores[0].tolist()
    top_indices = indices[0].cpu().numpy()

    # Map the class indices back to their original specialty names
    predicted_labels = label_encoder.inverse_transform(top_indices)

    return {label: score for label, score in zip(predicted_labels, top_scores)}
# Sample medical texts that can be used as ready-made inputs
examples = [
    (
        "Patient presenting with chest pain, shortness of breath, and palpitations. "
        "ECG shows atrial fibrillation."
    ),
    (
        "Aspiration of the knee joint was performed due to swelling "
        "and suspected septic arthritis."
    ),
    (
        "Post-operative report for a patient who underwent a hysterectomy "
        "due to uterine fibroids."
    ),
    (
        "Neurological examination revealed a positive Babinski sign and nystagmus, "
        "suggesting a central nervous system disorder."
    ),
]
# Build the Gradio UI around the classifier.
# Whether every model component is present decides the examples and the
# description shown to the user.
_components_ready = all([model, tokenizer, label_encoder])

# Upper bound on the classes the Label output displays; 3 when no class names exist
num_classes = len(class_names) or 3

if _components_ready:
    include_examples = examples
    description_text = "This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
else:
    # No examples in fallback mode
    include_examples = None
    description_text = "⚠️ Model loading failed (likely missing scikit-learn dependency). Please ensure all requirements are installed. The app is running in fallback mode."

# Input component first, then the output component (same order as before)
_text_input = gr.Textbox(
    lines=10,
    placeholder="Paste a medical document or text here...",
    label="Medical Text",
)
_label_output = gr.Label(num_top_classes=num_classes)

iface = gr.Interface(
    fn=predict_medical_specialty,
    inputs=_text_input,
    outputs=_label_output,
    title="Medical Case Classifier",
    description=description_text,
    examples=include_examples,
)

# Start the web UI only when executed as a script, not on import
if __name__ == "__main__":
    iface.launch()