File size: 4,438 Bytes
1718df6
09d0e11
 
 
 
079da61
 
 
09d0e11
 
079da61
 
 
 
 
 
 
09d0e11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa63877
09d0e11
 
 
 
079da61
 
09d0e11
 
 
fa63877
09d0e11
fa63877
3d89652
 
 
 
 
09d0e11
 
 
fa63877
09d0e11
 
 
 
 
 
 
 
fa63877
09d0e11
 
 
 
 
 
 
 
12b9611
 
 
 
 
 
 
fa63877
c22378f
 
 
 
 
 
 
fa63877
 
3d89652
 
 
 
079da61
 
 
 
 
 
fa63877
09d0e11
fa63877
 
 
 
 
3d89652
 
079da61
3d89652
fa63877
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pickle

# Placeholders for the classifier components; filled in below when loading succeeds.
model, tokenizer, label_encoder, class_names = None, None, None, []

# Load every saved artifact up front; on any failure, drop into a disabled
# fallback state that the UI code below can detect.
try:
    # Verify the scikit-learn dependency first — it is required to unpickle
    # the label encoder further down.
    try:
        import sklearn
        print("scikit-learn is available")
    except ImportError:
        raise ImportError("scikit-learn is not installed. Please install it using: pip install scikit-learn")

    # Directory holding the fine-tuned model, tokenizer, and label encoder.
    model_path = "./medical_classifier_model"

    # Prefer a GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the classifier weights and place them on the chosen device.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)

    # Restore the matching tokenizer from the same directory.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Restore the fitted label encoder from its pickle file.
    with open(f'{model_path}/label_encoder.pkl', 'rb') as fh:
        label_encoder = pickle.load(fh)

    # Specialty names in the order the encoder assigned to class indices.
    class_names = list(label_encoder.classes_)

    print("Model, tokenizer, and label encoder loaded successfully!")

except Exception as e:
    print(f"Error loading model components: {e}")
    print("The application will run in fallback mode with limited functionality.")
    # Reset everything so downstream code sees a consistent degraded state.
    model, tokenizer, label_encoder, class_names = None, None, None, []

def predict_medical_specialty(text):
    """Predict the medical specialty of a given text using the fine-tuned model.

    Parameters
    ----------
    text : str or None
        Free-form medical/clinical text to classify.

    Returns
    -------
    dict
        Up to three specialty names mapped to their softmax probabilities,
        suitable for a Gradio ``Label`` output. Returns ``{"No input": 1.0}``
        for empty or whitespace-only input, and ``{"Model Error": 1.0}`` when
        the model components failed to load at startup.
    """
    # Reject empty and whitespace-only input before touching the model.
    if not text or not text.strip():
        return {"No input": 1.0}

    # Fall back gracefully if any component failed to load at startup.
    if not all([model, tokenizer, label_encoder]):
        return {"Model Error": 1.0}

    # Ensure the model is in evaluation mode (disables dropout etc.).
    model.eval()

    # Tokenize the input and move the tensors to the model's device
    # (module-level `device`, set during the loading step).
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class dimension to get probabilities.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Top-k predictions for the single input in the batch. Indexing the batch
    # dimension explicitly ([0]) instead of squeeze() keeps the tensors 1-D
    # even when top_k == 1, where squeeze() would yield 0-d tensors and break
    # inverse_transform / zip.
    top_k = min(3, len(class_names))
    scores, indices = torch.topk(probabilities, k=top_k)
    predicted_labels = label_encoder.inverse_transform(indices[0].cpu().numpy())
    return {label: score.item() for label, score in zip(predicted_labels, scores[0])}

# Sample medical snippets offered as one-click examples in the Gradio UI.
examples = [
    "Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
    "Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
    "Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
    "Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder.",
]

# ---- Gradio UI wiring --------------------------------------------------
# The interface adapts to whether the model actually loaded: with a working
# model we show the example inputs and the normal description; otherwise we
# surface a warning and hide the examples.
_components_ready = all([model, tokenizer, label_encoder])

# Number of top classes the Label widget should display.
if class_names:
    num_classes = len(class_names)
else:
    num_classes = 3

# Only offer clickable examples when predictions can actually be made.
include_examples = examples if _components_ready else None

# Pick the description matching the model's availability.
if _components_ready:
    description_text = "This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
else:
    description_text = "⚠️ Model loading failed (likely missing scikit-learn dependency). Please ensure all requirements are installed. The app is running in fallback mode."

iface = gr.Interface(
    fn=predict_medical_specialty,
    inputs=gr.Textbox(lines=10, placeholder="Paste a medical document or text here...", label="Medical Text"),
    outputs=gr.Label(num_top_classes=num_classes),
    title="Medical Case Classifier",
    description=description_text,
    examples=include_examples,
)

# Start the web server only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()