Spaces:
Sleeping
Sleeping
# Import required libraries
import io
import os
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Suppress warnings for cleaner output.
# NOTE(review): this silences *every* Python warning process-wide (including
# deprecation notices from torch/transformers/gradio); consider narrowing it
# with `category=` / `module=` filters.
warnings.filterwarnings("ignore")

# --- 1. Load Model and Processor at Startup ---
# Hugging Face Hub id of the fine-tuned ViT classifier used by the app.
MODEL_NAME = "codewithdark/vit-chest-xray"
# Display names for the model's output classes, indexed by logit position.
# Defined outside the `try` below so the constant exists even if loading fails.
LABEL_COLUMNS = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']

print("Loading model and processor...")
try:
    # Load the preprocessing config and the model weights from Hugging Face.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    # from_pretrained already returns the model in evaluation mode; calling
    # eval() explicitly documents that this process only runs inference.
    model.eval()

    # Check for GPU and move the model if available for faster inference.
    if torch.cuda.is_available():
        print(f"GPU available: {torch.cuda.get_device_name(0)}. Moving model to GPU.")
        model.to("cuda")
    else:
        print("Warning: GPU not available, using CPU. Inference may be slow.")

    # LABEL_COLUMNS is maintained by hand, so flag any drift from the model head.
    num_labels = model.config.num_labels
    if num_labels != len(LABEL_COLUMNS):
        print(
            f"Warning: model reports {num_labels} output classes but "
            f"LABEL_COLUMNS defines {len(LABEL_COLUMNS)} names."
        )

    print("Model and processor loaded successfully.")
except Exception as e:
    # Startup boundary: keep the process alive with None sentinels. predict_xray
    # turns these into an error message, and the __main__ guard skips launching.
    print(f"Fatal Error: Could not load model or processor. {e}")
    processor, model = None, None
# --- 2. Define Prediction and Visualization Functions ---
def predict_xray(image: Image.Image) -> dict:
    """Classify a chest X-ray with the globally loaded ViT model.

    Args:
        image: The X-ray as a PIL image; non-RGB images are converted to RGB.

    Returns:
        ``{"label": str, "confidence": float}`` for the top-scoring class,
        where ``confidence`` is its softmax probability in [0, 1]; or
        ``{"error": str}`` when the model/processor failed to load at startup.
    """
    if model is None or processor is None:
        return {"error": "Model not loaded. Please restart the application."}

    # Uploaded images may be grayscale or carry an alpha channel; the model is fed RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Preprocess into a batch of one, then put every tensor on the device that
    # actually holds the model weights (instead of re-probing CUDA, which could
    # disagree with where the model ended up).
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single image => take row 0 of the (batch, num_classes) probabilities.
    probabilities = torch.softmax(logits, dim=-1)[0]
    top_prob, top_idx = torch.max(probabilities, dim=-1)
    class_idx = top_idx.item()

    if class_idx < len(LABEL_COLUMNS):
        predicted_label = LABEL_COLUMNS[class_idx]
    else:
        # The hand-written label list is shorter than the model head: fall back
        # to the model config's own names instead of raising IndexError.
        id2label = getattr(model.config, "id2label", None) or {}
        predicted_label = id2label.get(class_idx, f"Class {class_idx}")

    return {"label": predicted_label, "confidence": top_prob.item()}
def visualize_results(image: Image.Image, result: dict) -> Image.Image:
    """Render the prediction as a text label drawn on top of the input image.

    Args:
        image: The original X-ray image (any PIL mode; converted to RGB for display).
        result: Prediction dict with a ``"label"`` string and a ``"confidence"``
            probability in [0, 1], as produced by ``predict_xray``.

    Returns:
        A new PIL image (the rendered PNG), fully loaded into memory.
    """
    # `convert` returns a new image, so the caller's image is left untouched.
    img_array = np.array(image.convert("RGB"))

    # Use a standalone Figure rather than pyplot. Pyplot keeps global figure
    # state that is not thread-safe when requests are served from several
    # threads, and a plt.figure() leaks if anything raises before plt.close().
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    ax.imshow(img_array)
    ax.axis("off")

    # (10, 30) are data coordinates: with imshow's default upper origin these
    # are pixel offsets from the image's top-left corner.
    text = f"{result['label']}: {result['confidence']:.2%}"
    ax.text(10, 30, text, color="white", fontsize=12,
            bbox=dict(facecolor="black", alpha=0.6))

    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
        buffer.seek(0)
        # Image.open is lazy; copy() forces the pixel data into memory so the
        # returned image stays valid after the buffer is closed.
        return Image.open(buffer).copy()
# --- 3. Main Function for Gradio ---
def analyze_chest_xray(image: Image.Image) -> tuple:
    """Gradio click handler: classify the uploaded X-ray and build both outputs.

    Args:
        image: The uploaded image, or ``None`` when nothing has been uploaded.

    Returns:
        ``(annotated_image, message)``. On success, ``annotated_image`` is the
        input with the prediction drawn on it, and ``message`` holds the label
        and confidence on two lines. When there is no image or prediction
        reports an error, ``annotated_image`` is ``None`` and ``message``
        explains why.
    """
    # Guard clause: nothing uploaded yet.
    if image is None:
        return None, "Please upload an image to analyze."

    outcome = predict_xray(image)
    # predict_xray signals failure with an {"error": ...} dict instead of raising.
    if "error" in outcome:
        return None, outcome["error"]

    label, score = outcome["label"], outcome["confidence"]
    summary = "\n".join((f"Prediction: {label}", f"Confidence: {score:.2%}"))
    annotated = visualize_results(image, outcome)
    return annotated, summary
# --- 4. Define and Launch the Gradio Interface ---
# Spotify-inspired dark theme: near-black background with green (#1db954) accents.
# NOTE(review): `.gr-button` / `.gr-textbox` look like older Gradio DOM class
# names, and no component below assigns them via `elem_classes`. Confirm these
# rules still match something in the installed Gradio version.
custom_css = """
.gradio-container { background: #121212; font-family: 'Circular', 'Helvetica Neue', sans-serif; color: #ffffff; }
.title { font-size: 2.5em; color: #1db954; text-align: center; margin-bottom: 10px; font-weight: bold; }
.gr-button { background: #1db954 !important; color: #000000 !important; border-radius: 500px !important; font-weight: 700 !important; text-transform: uppercase; }
.gr-button:hover { background: #1ed760 !important; transform: scale(1.05); }
.output-image, .input-image { border: 2px solid #1db954 !important; border-radius: 12px !important; box-shadow: 0 4px 12px rgba(29, 185, 84, 0.3); }
.gr-textbox { background: #282828 !important; color: #ffffff !important; border: 1px solid #1db954 !important; border-radius: 8px !important; }
"""

with gr.Blocks(css=custom_css, title="Chest X-Ray Detection") as demo:
    # Header: styled title followed by a one-line description of the tool.
    gr.Markdown(
        "<h1 class='title'>Chest X-Ray Analysis System</h1>",
        elem_classes="title",
    )
    gr.Markdown(
        "Upload a chest X-ray image to classify potential abnormalities "
        "using a Vision Transformer (ViT)."
    )

    # Side-by-side panes: the upload on the left, the annotated result on the right.
    with gr.Row():
        image_input = gr.Image(
            type="pil",
            label="Upload Chest X-Ray",
            sources=["upload"],
            elem_classes="input-image",
        )
        output_image = gr.Image(
            label="Annotated Result",
            elem_classes="output-image",
        )

    predict_button = gr.Button(value="Analyze X-Ray")
    output_text = gr.Textbox(label="Prediction Details", lines=2)

    # The button feeds the uploaded image to analyze_chest_xray. Its
    # (image, text) return tuple fills the two output components in order.
    predict_button.click(
        analyze_chest_xray,
        inputs=[image_input],
        outputs=[output_image, output_text],
    )
if __name__ == "__main__":
    # The startup loader sets both globals to None on failure; compare against
    # that sentinel explicitly rather than relying on object truthiness.
    if model is not None and processor is not None:
        print("Launching Gradio interface...")
        # share=True requests a public share link when run locally (presumably
        # ignored on hosted Spaces — confirm). debug=True blocks the main thread
        # and surfaces errors in the console.
        demo.launch(share=True, debug=True)
    else:
        print("Gradio interface could not be launched because the model failed to load.")