adilsiraju commited on
Commit
c22378f
·
verified ·
1 Parent(s): fa63877

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -19,7 +19,6 @@ medical_specialties = [
19
  ]
20
 
21
  # Initialize the zero-shot classification pipeline
22
- # A better-performing, fine-tuned model could be used here.
23
  classifier = pipeline(
24
  "zero-shot-classification",
25
  model="facebook/bart-large-mnli",
@@ -28,7 +27,7 @@ classifier = pipeline(
28
 
29
  def classify_medical_text(text):
30
  """
31
- Classifies a medical text into one of the predefined medical specialties.
32
  """
33
  if not text:
34
  return {"Error": "Please provide some text to classify."}
@@ -36,12 +35,24 @@ def classify_medical_text(text):
36
  # Perform zero-shot classification
37
  result = classifier(text, medical_specialties)
38
 
39
- # Format the output for better display
40
- labels = result['labels']
41
- scores = result['scores']
42
 
43
- # Return the results as a dictionary for Gradio to display
44
- return {label: score for label, score in zip(labels, scores)}
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # Create the Gradio interface
47
  iface = gr.Interface(
@@ -51,9 +62,10 @@ iface = gr.Interface(
51
  placeholder="Paste a medical document or text here...",
52
  label="Medical Text"
53
  ),
54
- outputs=gr.Label(num_top_classes=len(medical_specialties)),
55
  title="Medical Document Classifier",
56
- description="This application uses a zero-shot classification model to predict the medical specialty of a given text."
 
57
  )
58
 
59
  # Launch the interface
 
19
  ]
20
 
21
  # Initialize the zero-shot classification pipeline
 
22
  classifier = pipeline(
23
  "zero-shot-classification",
24
  model="facebook/bart-large-mnli",
 
27
 
28
  def classify_medical_text(text):
29
  """
30
+ Classifies a medical text into one of the predefined medical specialties and returns the top 3 predictions.
31
  """
32
  if not text:
33
  return {"Error": "Please provide some text to classify."}
 
35
  # Perform zero-shot classification
36
  result = classifier(text, medical_specialties)
37
 
38
+ # Combine labels and scores
39
+ combined_results = zip(result['labels'], result['scores'])
 
40
 
41
+ # Sort the results by score in descending order and get the top 3
42
+ top_3_predictions = sorted(combined_results, key=lambda x: x[1], reverse=True)[:3]
43
+
44
+ # Format the output as a dictionary for Gradio
45
+ top_3_dict = {label: score for label, score in top_3_predictions}
46
+
47
+ return top_3_dict
48
+
49
+ # Define example medical texts
50
+ examples = [
51
+ "Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
52
+ "Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
53
+ "Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
54
+ "Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder."
55
+ ]
56
 
57
  # Create the Gradio interface
58
  iface = gr.Interface(
 
62
  placeholder="Paste a medical document or text here...",
63
  label="Medical Text"
64
  ),
65
+ outputs=gr.Label(num_top_classes=3),
66
  title="Medical Document Classifier",
67
+ description="This application uses a zero-shot classification model to predict the medical specialty of a given text. It will display the top 3 most likely specialties. Click on an example below to get started!",
68
+ examples=examples
69
  )
70
 
71
  # Launch the interface