adilsiraju committed on
Commit
09d0e11
·
1 Parent(s): fa63877
app.py CHANGED
@@ -1,59 +1,85 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- # Define the candidate labels for classification
5
- medical_specialties = [
6
- "Cardiovascular Pulmonary",
7
- "Orthopedic",
8
- "Nephrology",
9
- "ENT Otolaryngology",
10
- "Obstetrics Gynecology",
11
- "Ophthalmology",
12
- "Gastroenterology",
13
- "Neurology",
14
- "Radiology",
15
- "Psychiatry Psychology",
16
- "Pediatrics Neonatal",
17
- "Hematology Oncology",
18
- "Neurosurgery"
19
- ]
20
-
21
- # Initialize the zero-shot classification pipeline
22
- # A better-performing, fine-tuned model could be used here.
23
- classifier = pipeline(
24
- "zero-shot-classification",
25
- model="facebook/bart-large-mnli",
26
- device=-1 # Use -1 for CPU, or 0 for GPU if available
27
- )
28
 
29
- def classify_medical_text(text):
 
 
 
 
 
 
 
30
  """
31
- Classifies a medical text into one of the predefined medical specialties.
32
  """
33
- if not text:
34
- return {"Error": "Please provide some text to classify."}
 
 
 
35
 
36
- # Perform zero-shot classification
37
- result = classifier(text, medical_specialties)
 
 
 
 
 
 
38
 
39
- # Format the output for better display
40
- labels = result['labels']
41
- scores = result['scores']
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Return the results as a dictionary for Gradio to display
44
- return {label: score for label, score in zip(labels, scores)}
45
 
46
  # Create the Gradio interface
47
  iface = gr.Interface(
48
- fn=classify_medical_text,
49
  inputs=gr.Textbox(
50
  lines=10,
51
  placeholder="Paste a medical document or text here...",
52
  label="Medical Text"
53
  ),
54
- outputs=gr.Label(num_top_classes=len(medical_specialties)),
55
  title="Medical Document Classifier",
56
- description="This application uses a zero-shot classification model to predict the medical specialty of a given text."
57
  )
58
 
59
  # Launch the interface
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import pickle
5
+
6
+ # Load the saved model, tokenizer, and label encoder
7
# Directory that holds the fine-tuned model, the tokenizer files and
# label_encoder.pkl.
model_path = "./medical_classifier_model"

# Choose the device outside the try block so `device` is always defined,
# even when loading the model components below fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Load the model and move it to the selected device.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)

    # Load the tokenizer saved alongside the model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file. Only ship a label_encoder.pkl that was produced by your own
    # training run; never load one from an untrusted source.
    with open(f'{model_path}/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)

    # Class names in encoder order; used to size the top-k prediction list.
    class_names = list(label_encoder.classes_)
except Exception as e:
    # Top-level boundary: report the failure and fall back to "nothing
    # loaded" so the app can still start and show an error to the user.
    print(f"Error loading model components: {e}")
    model, tokenizer, label_encoder, class_names = None, None, None, []
else:
    print("Model, tokenizer, and label encoder loaded successfully!")
34
+
35
def predict_medical_specialty(text):
    """
    Predicts the medical specialty of a given text using the fine-tuned model.

    Args:
        text: Free-form medical text to classify.

    Returns:
        A dict mapping each specialty name to its softmax probability,
        ordered from the highest score to the lowest. If the input is empty
        or the model components are not loaded, a single-entry dict with the
        key "Error" and an explanatory message is returned instead.
    """
    # Validate the input first so an empty textbox is not misreported as a
    # model-loading failure.
    if not text or not text.strip():
        return {"Error": "Please provide some text to classify."}

    # The loading fallback assigns None to these names, so test for None
    # explicitly instead of relying on the objects' truthiness.
    if model is None or tokenizer is None or label_encoder is None:
        return {"Error": "Model not loaded correctly. Please check server logs."}

    # Ensure the model is in evaluation mode (disables dropout).
    model.eval()

    # Tokenize the input text and move the tensors to the model's device.
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One text is passed in, so the logits are expected to hold a single row.
    # Index it explicitly rather than squeeze(), which would also collapse
    # the class dimension when there is only one class.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Rank every class by probability, highest first.
    scores, indices = torch.topk(probabilities, k=len(class_names))

    # Map the class indices back to their original specialty names.
    predicted_labels = label_encoder.inverse_transform(indices.cpu().numpy())

    # Plain str keys and float values keep the result serializable.
    return {
        str(label): score
        for label, score in zip(predicted_labels, scores.tolist())
    }
 
71
 
72
  # Create the Gradio interface
73
  iface = gr.Interface(
74
+ fn=predict_medical_specialty,
75
  inputs=gr.Textbox(
76
  lines=10,
77
  placeholder="Paste a medical document or text here...",
78
  label="Medical Text"
79
  ),
80
+ outputs=gr.Label(num_top_classes=len(class_names)),
81
  title="Medical Document Classifier",
82
+ description="This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
83
  )
84
 
85
  # Launch the interface
appcopy.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Define the candidate labels for classification
5
+ medical_specialties = [
6
+ "Cardiovascular Pulmonary",
7
+ "Orthopedic",
8
+ "Nephrology",
9
+ "ENT Otolaryngology",
10
+ "Obstetrics Gynecology",
11
+ "Ophthalmology",
12
+ "Gastroenterology",
13
+ "Neurology",
14
+ "Radiology",
15
+ "Psychiatry Psychology",
16
+ "Pediatrics Neonatal",
17
+ "Hematology Oncology",
18
+ "Neurosurgery"
19
+ ]
20
+
21
+ # Initialize the zero-shot classification pipeline
22
+ # A better-performing, fine-tuned model could be used here.
23
+ classifier = pipeline(
24
+ "zero-shot-classification",
25
+ model="facebook/bart-large-mnli",
26
+ device=-1 # Use -1 for CPU, or 0 for GPU if available
27
+ )
28
+
29
+ def classify_medical_text(text):
30
+ """
31
+ Classifies a medical text into one of the predefined medical specialties.
32
+ """
33
+ if not text:
34
+ return {"Error": "Please provide some text to classify."}
35
+
36
+ # Perform zero-shot classification
37
+ result = classifier(text, medical_specialties)
38
+
39
+ # Format the output for better display
40
+ labels = result['labels']
41
+ scores = result['scores']
42
+
43
+ # Return the results as a dictionary for Gradio to display
44
+ return {label: score for label, score in zip(labels, scores)}
45
+
46
+ # Create the Gradio interface
47
+ iface = gr.Interface(
48
+ fn=classify_medical_text,
49
+ inputs=gr.Textbox(
50
+ lines=10,
51
+ placeholder="Paste a medical document or text here...",
52
+ label="Medical Text"
53
+ ),
54
+ outputs=gr.Label(num_top_classes=len(medical_specialties)),
55
+ title="Medical Document Classifier",
56
+ description="This application uses a zero-shot classification model to predict the medical specialty of a given text."
57
+ )
58
+
59
+ # Launch the interface
60
+ if __name__ == "__main__":
61
+ iface.launch()
medical_classifier_model/config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "dtype": "float32",
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "LABEL_0",
13
+ "1": "LABEL_1",
14
+ "2": "LABEL_2",
15
+ "3": "LABEL_3",
16
+ "4": "LABEL_4",
17
+ "5": "LABEL_5",
18
+ "6": "LABEL_6",
19
+ "7": "LABEL_7",
20
+ "8": "LABEL_8",
21
+ "9": "LABEL_9",
22
+ "10": "LABEL_10",
23
+ "11": "LABEL_11",
24
+ "12": "LABEL_12"
25
+ },
26
+ "initializer_range": 0.02,
27
+ "intermediate_size": 3072,
28
+ "label2id": {
29
+ "LABEL_0": 0,
30
+ "LABEL_1": 1,
31
+ "LABEL_10": 10,
32
+ "LABEL_11": 11,
33
+ "LABEL_12": 12,
34
+ "LABEL_2": 2,
35
+ "LABEL_3": 3,
36
+ "LABEL_4": 4,
37
+ "LABEL_5": 5,
38
+ "LABEL_6": 6,
39
+ "LABEL_7": 7,
40
+ "LABEL_8": 8,
41
+ "LABEL_9": 9
42
+ },
43
+ "layer_norm_eps": 1e-12,
44
+ "max_position_embeddings": 512,
45
+ "model_type": "bert",
46
+ "num_attention_heads": 12,
47
+ "num_hidden_layers": 12,
48
+ "pad_token_id": 0,
49
+ "position_embedding_type": "absolute",
50
+ "problem_type": "single_label_classification",
51
+ "transformers_version": "4.56.2",
52
+ "type_vocab_size": 2,
53
+ "use_cache": true,
54
+ "vocab_size": 28996
55
+ }
medical_classifier_model/label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bab7c256fe67b2cf75cc80ddcf92d92eda398d465ad84f1f4e2b1726306b3a2
3
+ size 1591
medical_classifier_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fde2b928afcf154ac6f041fef201e113e1fba1b34a58526364afa88481bf2d9
3
+ size 433304604
medical_classifier_model/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
medical_classifier_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
medical_classifier_model/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "extra_special_tokens": {},
49
+ "mask_token": "[MASK]",
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "never_split": null,
52
+ "pad_token": "[PAD]",
53
+ "sep_token": "[SEP]",
54
+ "strip_accents": null,
55
+ "tokenize_chinese_chars": true,
56
+ "tokenizer_class": "BertTokenizer",
57
+ "unk_token": "[UNK]"
58
+ }
medical_classifier_model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff