# Hugging Face Space app (Space status shown on the page when captured: Sleeping)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# -----------------------------
# 1. Load the pretrained model
# -----------------------------
# A single Hugging Face Hub checkpoint supplies both the classifier weights and
# the matching image-preprocessing configuration.
model_name = "dima806/chest_xray_pneumonia_detection"

# The app only runs inference, so the classifier is switched to evaluation
# mode right after loading (nn.Module.eval() returns the module itself).
model = AutoModelForImageClassification.from_pretrained(model_name).eval()
processor = AutoImageProcessor.from_pretrained(model_name)
# -----------------------------
# 2. Prediction function
# -----------------------------
def _pneumonia_index(id2label):
    """Return the class index whose label names the pneumonia class.

    Looks up the label "pneumonia" (case-insensitive, surrounding whitespace
    ignored) in the model's ``id2label`` mapping, so the answer does not depend
    on the order in which the checkpoint stores its classes. Falls back to
    index 1 (the assumed ordering 0 = no pneumonia, 1 = pneumonia) when no such
    label exists, e.g. for generic "LABEL_0"/"LABEL_1" names or a missing map.

    Args:
        id2label: Mapping from class index (int or numeric str) to label name,
            or ``None``.

    Returns:
        The integer index of the pneumonia class.
    """
    for idx, label in (id2label or {}).items():
        if str(label).strip().lower() == "pneumonia":
            return int(idx)
    return 1


def predict(image):
    """Classify a chest X-ray and report whether pneumonia is predicted.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        "Pneumonia: YES" if the top-scoring class is the pneumonia class,
        otherwise "Pneumonia: NO". If ``image`` is ``None``, returns a prompt
        asking the user to upload an image.
    """
    # Gradio hands us None for an empty upload; calling .convert on it would
    # raise AttributeError and show a generic error in the UI.
    if image is None:
        return "Please upload a chest X-ray image."

    # Force 3 channels: X-rays are often grayscale, and some image models
    # require RGB input.
    img = image.convert("RGB")
    inputs = processor(images=img, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Resolve which index means "pneumonia" from the model's own label map
    # instead of assuming its position in the class list.
    pneumonia_idx = _pneumonia_index(getattr(model.config, "id2label", None))
    is_pneumonia = predicted_class_idx == pneumonia_idx
    return "Pneumonia: YES" if is_pneumonia else "Pneumonia: NO"
# -----------------------------
# 3. Gradio interface
# -----------------------------
# Single-image-in / text-out UI around `predict`; type="pil" makes Gradio pass
# a PIL image (or None when nothing was uploaded).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="State-of-the-Art Pneumonia Detector",
    description="Upload a chest X-ray. The model predicts if pneumonia is present (YES/NO).",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests) does not block on a running server.
# NOTE(review): assumes the Space launches the app via `python app.py` — confirm.
if __name__ == "__main__":
    iface.launch()