import pickle

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Globals populated by the loading step below; None / [] means fallback mode.
model, tokenizer, label_encoder, class_names = None, None, None, []

# Pick the device up front so it is always defined, even when loading fails
# (the original bound it inside the try-block, so an early ImportError left
# `device` unbound while predict_medical_specialty still referenced it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the saved model, tokenizer, and label encoder.
try:
    # Fail early with a clear message if scikit-learn (needed to unpickle
    # the LabelEncoder) is missing.
    try:
        import sklearn
        print("scikit-learn is available")
    except ImportError:
        raise ImportError("scikit-learn is not installed. Please install it using: pip install scikit-learn")

    # Use the correct path where you saved your model
    model_path = "./medical_classifier_model"

    # Load the model and move it to the selected device.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Load the label encoder.
    # NOTE: pickle.load is only safe because this file is produced locally
    # alongside the model; never unpickle untrusted data.
    with open(f'{model_path}/label_encoder.pkl', 'rb') as f:
        label_encoder = pickle.load(f)

    # Get the class names from the label encoder.
    class_names = list(label_encoder.classes_)

    print("Model, tokenizer, and label encoder loaded successfully!")
except Exception as e:
    print(f"Error loading model components: {e}")
    print("The application will run in fallback mode with limited functionality.")
    # Fallback values when loading fails.
    model, tokenizer, label_encoder, class_names = None, None, None, []


def predict_medical_specialty(text):
    """
    Predict the medical specialty of a given text using the fine-tuned model.

    Args:
        text: The medical document/text to classify.

    Returns:
        A dict mapping the top-k (k <= 3) specialty names to their softmax
        probabilities — the format expected by gr.Label. Sentinel dicts
        ({"No input": 1.0} / {"Model Error": 1.0}) are returned for empty
        input or when the model failed to load.
    """
    if not text:
        return {"No input": 1.0}
    # Explicit None checks: truthiness of model/tokenizer objects is not a
    # reliable "loaded" signal.
    if model is None or tokenizer is None or label_encoder is None:
        return {"Model Error": 1.0}

    # Ensure the model is in evaluation mode (disables dropout etc.).
    model.eval()

    # Tokenize the input and move the tensors to the model's device.
    inputs = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class dimension to get probabilities.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Take at most 3 classes; fewer if the model has fewer labels.
    top_k = min(3, len(class_names))
    scores, indices = torch.topk(probabilities, k=top_k)

    # Index row 0 explicitly instead of .squeeze(): squeeze() collapses a
    # (1, 1) topk result to a 0-d tensor when top_k == 1, which breaks
    # inverse_transform (needs a 1-d array) and the zip below.
    predicted_labels = label_encoder.inverse_transform(indices[0].cpu().numpy())

    # Map each predicted specialty name to its probability.
    return {
        label: score.item()
        for label, score in zip(predicted_labels, scores[0])
    }


# Example medical texts shown in the UI (only offered when the model loaded).
examples = [
    "Patient presenting with chest pain, shortness of breath, and palpitations. ECG shows atrial fibrillation.",
    "Aspiration of the knee joint was performed due to swelling and suspected septic arthritis.",
    "Post-operative report for a patient who underwent a hysterectomy due to uterine fibroids.",
    "Neurological examination revealed a positive Babinski sign and nystagmus, suggesting a central nervous system disorder."
]

# Determine the number of top classes to show and whether to include examples.
num_classes = len(class_names) if class_names else 3
include_examples = examples if all([model, tokenizer, label_encoder]) else None

# Create an appropriate description based on model availability.
if all([model, tokenizer, label_encoder]):
    description_text = "This application uses a fine-tuned Bio_ClinicalBERT model to predict the medical specialty of a given text."
else:
    description_text = "⚠️ Model loading failed (likely missing scikit-learn dependency). Please ensure all requirements are installed. The app is running in fallback mode."

# Build the Gradio interface.
iface = gr.Interface(
    fn=predict_medical_specialty,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Paste a medical document or text here...",
        label="Medical Text"
    ),
    outputs=gr.Label(num_top_classes=num_classes),
    title="Medical Case Classifier",
    description=description_text,
    examples=include_examples
)

# Launch the interface only when run as a script.
if __name__ == "__main__":
    iface.launch()