Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,6 @@ medical_specialties = [
|
|
| 19 |
]
|
| 20 |
|
| 21 |
# Initialize the zero-shot classification pipeline
|
| 22 |
-
# A better-performing, fine-tuned model could be used here.
|
| 23 |
classifier = pipeline(
|
| 24 |
"zero-shot-classification",
|
| 25 |
model="facebook/bart-large-mnli",
|
|
@@ -28,7 +27,7 @@ classifier = pipeline(
|
|
| 28 |
|
| 29 |
def classify_medical_text(text):
|
| 30 |
"""
|
| 31 |
-
Classifies a medical text into one of the predefined medical specialties.
|
| 32 |
"""
|
| 33 |
if not text:
|
| 34 |
return {"Error": "Please provide some text to classify."}
|
|
@@ -36,12 +35,24 @@ def classify_medical_text(text):
|
|
| 36 |
# Perform zero-shot classification
|
| 37 |
result = classifier(text, medical_specialties)
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
scores = result['scores']
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# Create the Gradio interface
|
| 47 |
iface = gr.Interface(
|
|
@@ -51,9 +62,10 @@ iface = gr.Interface(
|
|
| 51 |
placeholder="Paste a medical document or text here...",
|
| 52 |
label="Medical Text"
|
| 53 |
),
|
| 54 |
-
outputs=gr.Label(num_top_classes=
|
| 55 |
title="Medical Document Classifier",
|
| 56 |
-
description="This application uses a zero-shot classification model to predict the medical specialty of a given text."
|
|
|
|
| 57 |
)
|
| 58 |
|
| 59 |
# Launch the interface
|
|
|
|
| 19 |
]
|
| 20 |
|
| 21 |
# Initialize the zero-shot classification pipeline
|
|
|
|
| 22 |
classifier = pipeline(
|
| 23 |
"zero-shot-classification",
|
| 24 |
model="facebook/bart-large-mnli",
|
|
|
|
| 27 |
|
| 28 |
def classify_medical_text(text):
|
| 29 |
"""
|
| 30 |
+
Classifies a medical text into one of the predefined medical specialties and returns the top 3 predictions.
|
| 31 |
"""
|
| 32 |
if not text:
|
| 33 |
return {"Error": "Please provide some text to classify."}
|
|
|
|
| 35 |
# Perform zero-shot classification
|
| 36 |
result = classifier(text, medical_specialties)
|
| 37 |
|
| 38 |
+
# Combine labels and scores
|
| 39 |
+
combined_results = zip(result['labels'], result['scores'])
|
|
|
|
| 40 |
|
| 41 |
+
# Sort the results by score in descending order and get the top 3
|
| 42 |
+
top_3_predictions = sorted(combined_results, key=lambda x: x[1], reverse=True)[:3]
|
| 43 |
+
|
| 44 |
+
# Format the output as a dictionary for Gradio
|
| 45 |
+
top_3_dict = {label: score for label, score in top_3_predictions}
|
| 46 |
+
|
| 47 |
+
return top_3_dict
|
| 48 |
+
|
| 49 |
+
# Define example medical texts
|
| 50 |
+
examples = [
|
| 51 |
+
"Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
|
| 52 |
+
"Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
|
| 53 |
+
"Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
|
| 54 |
+
"Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder."
|
| 55 |
+
]
|
| 56 |
|
| 57 |
# Create the Gradio interface
|
| 58 |
iface = gr.Interface(
|
|
|
|
| 62 |
placeholder="Paste a medical document or text here...",
|
| 63 |
label="Medical Text"
|
| 64 |
),
|
| 65 |
+
outputs=gr.Label(num_top_classes=3),
|
| 66 |
title="Medical Document Classifier",
|
| 67 |
+
description="This application uses a zero-shot classification model to predict the medical specialty of a given text. It will display the top 3 most likely specialties. Click on an example below to get started!",
|
| 68 |
+
examples=examples
|
| 69 |
)
|
| 70 |
|
| 71 |
# Launch the interface
|