from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from PIL import Image
import gradio as gr

# Gradio demo that screens an uploaded chest X-ray for a few lung conditions
# using a fine-tuned ViT image classifier from the Hugging Face Hub.
#
# NOTE(review): outputs come from an ML model and are not a medical diagnosis.

model_name = "codewithdark/vit-chest-xray"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()  # evaluation mode for inference

# Label names in the order of the classifier's output logits.
# NOTE(review): this order is assumed to match model.config.id2label for the
# checkpoint above -- confirm, otherwise probabilities are attributed to the
# wrong condition.
labels = ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
# Subset of labels reported to the user, and their logit indices.
target_labels = ['Pneumonia', 'Consolidation', 'Edema']
target_idxs = [labels.index(lbl) for lbl in target_labels]

# Per-label probability strictly above which a condition is reported as "YES".
THRESHOLD = 0.5


def predict(image):
    """Screen a chest X-ray image for the conditions in ``target_labels``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A text report: one ``"<label>: YES|NO (<prob>)"`` line per target
        label, a blank line, then a one-line summary. If no image was given,
        a short prompt asking the user to upload one.
    """
    # Gradio hands the function None when the image input is empty; without
    # this guard `image.mode` below raises AttributeError.
    if image is None:
        return "Please upload a chest X-ray image."

    # Normalize to 3-channel RGB (e.g. grayscale "L" or "RGBA" uploads).
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize/normalize exactly as the checkpoint's processor specifies.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Independent sigmoid per label, then drop the batch dimension:
    # [1, num_labels] -> [num_labels].
    # NOTE(review): assumes a multi-label head (trained with BCE); if it was
    # trained with softmax cross-entropy these scores are not calibrated.
    probs = torch.sigmoid(logits)[0]

    detected = []
    results = []
    for idx, lbl in zip(target_idxs, target_labels):
        prob = probs[idx].item()
        status = "YES" if prob > THRESHOLD else "NO"
        results.append(f"{lbl}: {status} ({prob:.2f})")
        if status == "YES":
            detected.append(lbl)

    if detected:
        summary = f"⚠️ Patient shows signs of: {', '.join(detected)}."
    else:
        summary = "✅ Patient appears healthy — no major lung issues detected."

    # The leading "\n" on the summary yields a blank line before it.
    return "\n".join(results + ["\n" + summary])


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray Disease Detector",
    description="Upload a chest X-ray to detect Pneumonia, Consolidation, and Edema. Gives clear patient health summary."
)

iface.launch()