Spaces:
Sleeping
Sleeping
adilsiraju
Refactor prediction logic to return top 3 medical specialty predictions and remove unused zero-shot classification code
12b9611
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import pickle | |
# Globals shared with predict_medical_specialty(); populated by the loader below.
model, tokenizer, label_encoder, class_names = None, None, None, []

# Select the compute device up front so it is always defined, even when the
# loading block below fails early (the prediction code references it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model, tokenizer, and label encoder.
try:
    # Fail fast with a clear message if scikit-learn (required to unpickle
    # the LabelEncoder) is missing.
    try:
        import sklearn
        print("scikit-learn is available")
    except ImportError:
        raise ImportError("scikit-learn is not installed. Please install it using: pip install scikit-learn")

    # Path where the fine-tuned model artifacts were saved.
    model_path = "./medical_classifier_model"

    # Load the model and move it to the selected device.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Load the label encoder (maps class indices back to specialty names).
    with open(f'{model_path}/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)

    # Human-readable class names, in label-index order.
    class_names = list(label_encoder.classes_)
    print("Model, tokenizer, and label encoder loaded successfully!")
except Exception as e:
    print(f"Error loading model components: {e}")
    print("The application will run in fallback mode with limited functionality.")
    # Fallback values when loading fails; the UI checks these and degrades.
    model, tokenizer, label_encoder, class_names = None, None, None, []
def predict_medical_specialty(text):
    """
    Predict the medical specialty of a given clinical text.

    Args:
        text: Free-form medical text to classify.

    Returns:
        dict mapping specialty name -> probability for the top (up to 3)
        predicted classes. Returns {"No input": 1.0} for empty/blank input
        and {"Model Error": 1.0} when the model components failed to load.
    """
    # Guard: empty or whitespace-only input.
    if not text or not text.strip():
        return {"No input": 1.0}
    # Guard: model components failed to load at startup (fallback mode).
    if not all([model, tokenizer, label_encoder]):
        return {"Model Error": 1.0}

    # Ensure the model is in evaluation mode (disables dropout etc.).
    model.eval()

    # Derive the device from the model itself instead of relying on a
    # module-level `device` global that may be undefined on the failure path.
    target_device = next(model.parameters()).device

    # Tokenize the input text and move tensors to the model's device.
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    ).to(target_device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class dimension -> shape (1, num_classes).
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Cap k by the number of known labels (handles < 3 classes).
        top_k = min(3, len(class_names))
        scores, indices = torch.topk(probabilities, k=top_k)
        # Index row 0 (batch size is 1) rather than squeeze(): squeeze()
        # yields a 0-d array when top_k == 1, which breaks inverse_transform.
        predicted_labels = label_encoder.inverse_transform(indices[0].cpu().numpy())
        # Map each specialty name to its probability.
        return {label: score.item() for label, score in zip(predicted_labels, scores[0])}
# Example medical texts offered as one-click inputs in the Gradio UI below;
# each illustrates a different specialty (cardiology, orthopedics/rheumatology,
# gynecologic surgery, neurology — presumably; verify against the model's labels).
examples = [
    "Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
    "Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
    "Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
    "Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder."
]
# Assemble the Gradio interface.
# Single readiness flag: True only when every model component loaded.
components_loaded = all([model, tokenizer, label_encoder])

# Show one label row per known class; fall back to 3 when no labels loaded.
top_class_count = len(class_names) if class_names else 3

# Offer the example texts only when the classifier is actually usable.
example_inputs = examples if components_loaded else None

# Description reflects whether the app is in normal or fallback mode.
app_description = (
    "This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
    if components_loaded
    else "⚠️ Model loading failed (likely missing scikit-learn dependency). Please ensure all requirements are installed. The app is running in fallback mode."
)

iface = gr.Interface(
    fn=predict_medical_specialty,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste a medical document or text here...",
        label="Medical Text",
    ),
    outputs=gr.Label(num_top_classes=top_class_count),
    title="Medical Case Classifier",
    description=app_description,
    examples=example_inputs,
)
# Launch the interface.
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    iface.launch()