# X-vis / app.py
# resberry's picture
# Update app.py
# e30af19 verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import os
import time
# =========================
# Image preprocessing
# =========================
# Per-channel mean / standard deviation handed to Normalize.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Steps are applied in list order: resize, convert to tensor, normalise.
_preprocessing_steps = [
    transforms.Resize((224, 224)),  # Required for ResNet50
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# =========================
# Model definition
# =========================
class FineTunedResNet(nn.Module):
    """ResNet-50 whose final ``fc`` layer is replaced by an MLP classifier head.

    The head consists of three ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)``
    stages (1024, 512 and 256 units) followed by a ``Linear`` layer that emits
    ``num_classes`` outputs.

    Args:
        num_classes: Size of the final output layer.
        weights: Value forwarded to ``models.resnet50(weights=...)``. Defaults
            to ``models.ResNet50_Weights.DEFAULT``, which is what this class
            always used before the parameter existed. Pass ``None`` to build
            the backbone without pretrained weights, e.g. when a complete
            checkpoint is loaded with ``load_state_dict`` right afterwards.
    """

    def __init__(self, num_classes=4, weights=models.ResNet50_Weights.DEFAULT):
        super().__init__()
        self.resnet = models.resnet50(weights=weights)

        # Read the backbone's feature width before the original fc is replaced.
        backbone_features = self.resnet.fc.in_features

        # Layer order (and therefore the Sequential indices used as
        # state_dict keys) must stay exactly as below.
        self.resnet.fc = nn.Sequential(
            nn.Linear(backbone_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement head."""
        return self.resnet(x)
# =========================
# Load model
# =========================
MODEL_PATH = "models/final_fine_tuned_resnet50.pth"

# Stop at import time with an explicit message when the checkpoint is absent.
checkpoint_present = os.path.exists(MODEL_PATH)
if not checkpoint_present:
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

# Build the network first, then overwrite its parameters from the checkpoint.
model = FineTunedResNet(num_classes=4)
# NOTE(review): torch.load unpickles the file; only ship checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location="cpu")
)
model.eval()  # inference mode for the BatchNorm / Dropout layers
model.to("cpu")

# Display labels; predict() looks them up by the model's output index.
CLASSES = [
    "🦠 COVID",
    "🫁 Normal",
    "🦠 Pneumonia",
    "🦠 TB",
]
# =========================
# Prediction function
# =========================
def predict(image: Image.Image) -> str:
    """Classify an image and return a human-readable summary.

    Args:
        image: PIL image to classify. ``None`` is tolerated (for instance when
            the form is submitted without an upload) and yields a prompt
            instead of an exception.

    Returns:
        A text block listing the most probable classes (at most three) with
        their softmax probabilities, followed by the elapsed time in seconds;
        or a short prompt when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload a chest X-ray image first."

    # perf_counter is monotonic, so the elapsed time cannot go negative if the
    # wall clock is adjusted mid-call.
    start = time.perf_counter()

    # The Normalize step uses three-element mean/std, so force three channels
    # in case the caller passes a grayscale or RGBA image.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)[0]
        # Never request more entries than there are classes.
        top_probs, top_idxs = torch.topk(probs, min(3, probs.numel()))

    elapsed = time.perf_counter() - start

    lines = ["Top Predictions:", ""]
    lines.extend(
        f"{CLASSES[idx]}: {prob:.4f}"
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    )
    lines.append("")
    lines.append(f"⏱️ Prediction Time: {elapsed:.2f} seconds")
    return "\n".join(lines)
# =========================
# Example images
# =========================
# Sample files grouped by their sub-folder under "examples/", in display order.
_EXAMPLE_FILES = (
    (
        "Pneumonia",
        (
            "02009view1_frontal.jpg",
            "02055view1_frontal.jpg",
            "03152view1_frontal.jpg",
        ),
    ),
    (
        "COVID",
        (
            "11547_2020_1200_Fig3_HTML-a.png",
            "11547_2020_1200_Fig3_HTML-b.png",
            "11547_2020_1203_Fig1_HTML-b.png",
        ),
    ),
    (
        "Normal",
        (
            "06bc1cfe-23a0-43a4-a01b-dfa10314bbb0.jpg",
            "08ae6c0b-d044-4de2-a410-b3cf8dc65868.jpg",
            "IM-0178-0001.jpeg",
        ),
    ),
)

# Each entry is a one-element list: a single path per example row.
examples = [
    [f"examples/{folder}/{filename}"]
    for folder, filenames in _EXAMPLE_FILES
    for filename in filenames
]
# =========================
# Visualization images
# =========================
# Paths "pictures/1.png" through "pictures/5.png", in numeric order.
visualization_images = [f"pictures/{number}.png" for number in range(1, 6)]
def display_visualizations():
    """Load every image listed in ``visualization_images``.

    Returns:
        A list of PIL images, one per path, in the same order as
        ``visualization_images``.

    Raises:
        FileNotFoundError: If one of the listed paths does not exist
            (propagated from ``Image.open``).
    """
    loaded = []
    for path in visualization_images:
        # Image.open is lazy and keeps the file open; read the pixel data now
        # and let the context manager release the file handle right away.
        with Image.open(path) as picture:
            picture.load()
        loaded.append(picture)
    return loaded
# =========================
# Gradio interfaces
# =========================
# Widgets for the prediction tab: one image upload in, one text box out.
_xray_input = gr.Image(type="pil", label="Upload Chest X-ray")
_result_box = gr.Textbox(label="Prediction Result")

# Text shown under the title (leading and trailing newline are intentional).
_prediction_description = (
    "\n"
    "Upload a chest X-ray image to detect:\n"
    "🦠 COVID-19 • 🦠 Pneumonia • 🫁 Normal • 🦠 Tuberculosis\n"
)

prediction_interface = gr.Interface(
    fn=predict,
    inputs=_xray_input,
    outputs=_result_box,
    examples=examples,
    cache_examples=False,  # IMPORTANT for HF Spaces
    title="Lung Disease Detection XVI",
    description=_prediction_description,
)
# One image output per visualization file, labelled "Visualization 1", "2", ...
_visualization_outputs = []
for _position, _path in enumerate(visualization_images, start=1):
    _visualization_outputs.append(
        gr.Image(type="pil", label=f"Visualization {_position}")
    )

# No inputs: the tab only displays what display_visualizations() returns.
visualization_interface = gr.Interface(
    fn=display_visualizations,
    inputs=None,
    outputs=_visualization_outputs,
    title="Model Performance Visualizations",
)
# Tab label paired with the interface shown under it, in display order.
_TABS = (
    ("Predict", prediction_interface),
    ("Model Performance", visualization_interface),
)
app = gr.TabbedInterface(
    interface_list=[interface for _, interface in _TABS],
    tab_names=[tab_name for tab_name, _ in _TABS],
)
# =========================
# Launch (HF Spaces safe)
# =========================
# Started at import time with default launch settings.
app.launch()