# GD-VGG19 / app.py
# Author: mo-456 — commit 6d5d82d ("Update app.py", verified)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Human-readable description for each class, shown under the top prediction.
# Dict insertion order doubles as the model's output-index order: CLASS_NAMES
# below is derived from these keys, so the two can never disagree.
# NOTE(review): presumably matches the BloodMNIST label indices 0-7 used at
# training time — confirm before reordering.
CLASS_INFO = {
    "Basophil": "Rare granulocyte involved in allergic responses and parasitic defense.",
    "Eosinophil": "Granulocyte elevated in allergic conditions and helminth infections.",
    "Erythroblast": "Nucleated erythrocyte precursor; abnormal in peripheral blood.",
    "Immature Granulocyte": "Premature granulocyte — left shift; indicative of active infection or marrow stress.",
    "Lymphocyte": "Core adaptive immunity cell; T/B lineage distinction requires further analysis.",
    "Monocyte": "Large agranulocyte; differentiates into macrophages and dendritic cells.",
    "Neutrophil": "Primary innate responder to bacterial infection; most abundant leukocyte.",
    "Platelet": "Anucleate thrombocyte fragment essential for haemostasis.",
}

# Class label for each classifier output index: probs[i] belongs to CLASS_NAMES[i].
CLASS_NAMES = list(CLASS_INFO)
# Inference-time image pipeline: resize to the 224x224 input VGG19 expects,
# convert the PIL image to a float tensor, then normalise each channel with
# the ImageNet mean/std statistics commonly used with torchvision VGG models.
# NOTE(review): presumably identical to the training-time transform — confirm,
# since any mismatch silently degrades accuracy.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ---------------------------------------------------------------------------
# Model definition — must mirror training code exactly
# ---------------------------------------------------------------------------
class LowRankLinear(nn.Module):
    """Approximate a dense linear layer with two thin linear layers.

    ``fc_in`` maps ``in_features -> rank`` without a bias, and ``fc_out`` maps
    ``rank -> out_features`` with a bias, so the forward pass is
    ``fc_out(fc_in(x))``. The attribute names ``fc_in`` / ``fc_out`` determine
    the state-dict keys matched by ``load_state_dict``, so they must not be
    renamed.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Inner (bottleneck) dimension of the factorisation.
        bias: Optional tensor copied into ``fc_out.bias``. When ``None``, the
            default ``nn.Linear`` bias initialisation is kept.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, bias=None):
        super().__init__()
        self.rank = rank
        # Registration order (fc_in, then fc_out) fixes the parameters() order.
        self.fc_in = nn.Linear(in_features, rank, bias=False)
        self.fc_out = nn.Linear(rank, out_features, bias=True)
        if bias is None:
            return
        # Overwrite the output bias in place without recording it in autograd.
        with torch.no_grad():
            self.fc_out.bias.copy_(bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` into the rank-sized space, then out to ``out_features``."""
        bottleneck = self.fc_in(x)
        return self.fc_out(bottleneck)
def build_model(num_classes: int = 8, rank_fc0: int = 256, rank_fc3: int = 256) -> nn.Module:
    """Rebuild the SVD-compressed VGG19 architecture.

    Starts from an uninitialised torchvision VGG19 (``weights=None``; real
    weights are expected to come from a checkpoint). It then replaces the
    classifier head with the same 7-module layout as stock VGG, but with the
    two large dense layers swapped for ``LowRankLinear``. Keeping the layout
    keeps the state-dict indices ``classifier.0`` / ``classifier.3`` /
    ``classifier.6`` stable.

    Args:
        num_classes: Width of the final output layer.
        rank_fc0: Rank of the factorised ``classifier.0`` layer (25088 -> 4096).
        rank_fc3: Rank of the factorised ``classifier.3`` layer (4096 -> 4096).

    Returns:
        The assembled model with freshly initialised parameters.
    """
    net = models.vgg19(weights=None)
    head_layers = [
        LowRankLinear(25088, 4096, rank_fc0),  # classifier.0 (512*7*7 flattened features)
        nn.ReLU(inplace=True),                 # classifier.1
        nn.Dropout(0.5),                       # classifier.2
        LowRankLinear(4096, 4096, rank_fc3),   # classifier.3
        nn.ReLU(inplace=True),                 # classifier.4
        nn.Dropout(0.5),                       # classifier.5
        nn.Linear(4096, num_classes),          # classifier.6 (logits)
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
# ---------------------------------------------------------------------------
# Load checkpoint
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is a dict holding the model state dict plus the two SVD ranks
# chosen at training time; the architecture is rebuilt from those ranks before
# the weights are loaded.
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary objects
# (which can execute code). Acceptable only because this file ships with the
# app; never point this at an untrusted or user-supplied checkpoint.
# Load onto the CPU: build_model() creates the model on the CPU and it is moved
# to `device` afterwards, so mapping straight to the GPU would leave a second,
# never-freed GPU copy of every weight referenced by the module-level `ckpt`.
ckpt = torch.load("vgg19_svd_final.pth", map_location="cpu", weights_only=False)
rank_fc0 = ckpt["rank_fc0"]
rank_fc3 = ckpt["rank_fc3"]

# Size the output layer from CLASS_NAMES so the label list and head cannot drift apart.
model = build_model(num_classes=len(CLASS_NAMES), rank_fc0=rank_fc0, rank_fc3=rank_fc3)
model.load_state_dict(ckpt["model_state_dict"])  # strict by default: keys and shapes must match
model.to(device).eval()  # eval() disables Dropout for inference
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: np.ndarray):
    """Classify one blood-cell image and build the two UI outputs.

    Args:
        image: Image array delivered by the ``gr.Image(type="numpy")`` input,
            or ``None`` when no image is present.

    Returns:
        ``(label_dict, summary)``. ``label_dict`` maps every class name to its
        softmax probability (consumed by ``gr.Label``). ``summary`` is a
        Markdown string with the top class, its confidence in percent, and its
        description. When ``image`` is ``None``, returns ``None`` plus a prompt
        message instead.
    """
    if image is None:
        return None, "Upload or capture an image to begin."

    pixels = image.astype("uint8")
    # An (H, W, 1) array has no Pillow mode; treat it as plain grayscale.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode="RGB" in fromarray is deprecated (Pillow 11.3) and fails on,
    # or mis-decodes, grayscale/RGBA arrays; .convert("RGB") handles both.
    img = Image.fromarray(pixels).convert("RGB")
    tensor = preprocess(img).unsqueeze(0).to(device)  # add a batch dim of 1

    with torch.no_grad():
        logits = model(tensor)
    # Take the single batch item explicitly instead of squeeze(), so the
    # result is always a 1-D array of per-class probabilities.
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    top_idx = int(np.argmax(probs))
    top_name = CLASS_NAMES[top_idx]
    top_conf = float(probs[top_idx]) * 100.0

    label_dict = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    summary = f"**{top_name}** — {top_conf:.1f}% confidence\n\n{CLASS_INFO[top_name]}"
    return label_dict, summary
# ---------------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------------
# Build the Gradio UI. Components are placed in the order they are created
# inside each `with` context, so statement order here defines the page layout.
with gr.Blocks(
    # Minimal theme: slate primary / zinc neutral palette with the Inter font.
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.slate,
        neutral_hue=gr.themes.colors.zinc,
        font=gr.themes.GoogleFont("Inter"),
    ),
    # Custom CSS: caps the app width at 900px (centered) and styles the
    # header column below, which carries elem_id="title".
    css="""
.gradio-container { max-width: 900px; margin: 0 auto; }
#title { border-bottom: 1px solid #e5e7eb; padding-bottom: 1.25rem; margin-bottom: 1.5rem; }
#title h1 { font-size: 1.4rem; font-weight: 600; color: #111; letter-spacing: -0.02em; }
#title p { font-size: 0.875rem; color: #6b7280; margin-top: 0.25rem; }
""",
    title="Blood Cell Classifier",  # browser tab title
) as demo:
    # Header; elem_id="title" is what the #title CSS rules above target.
    with gr.Column(elem_id="title"):
        gr.HTML("""
<h1>Blood Cell Classifier</h1>
<p>VGG19 + SVD low-rank compression &nbsp;|&nbsp; BloodMNIST &nbsp;|&nbsp; 8-class</p>
""")
    # Two equal-width columns: input on the left, results on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # type="numpy" makes Gradio hand predict() a NumPy array (or None).
            image_input = gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "webcam", "clipboard"],
                height=280,
            )
            run_btn = gr.Button("Run inference", variant="primary")
        with gr.Column(scale=1):
            # Shows all 8 per-class probabilities from predict()'s label dict.
            label_output = gr.Label(label="Class probabilities", num_top_classes=8)
            # Markdown summary; starts with the same prompt predict() returns for None.
            summary_output = gr.Markdown(value="Upload or capture an image to begin.")
    # Inference runs on an explicit button click and also whenever the image
    # input changes. NOTE(review): presumably clearing the image sends None,
    # which predict() answers with the prompt message — confirm in Gradio docs.
    run_btn.click(fn=predict, inputs=image_input, outputs=[label_output, summary_output])
    image_input.change(fn=predict, inputs=image_input, outputs=[label_output, summary_output])
    # Footer with course attribution and the reported model metric.
    gr.HTML("""
<p style="font-size:0.75rem;color:#9ca3af;margin-top:1.5rem;border-top:1px solid #f3f4f6;padding-top:1rem;">
Applied Statistics in AI &nbsp;&middot;&nbsp; Cairo University 2026 &nbsp;&middot;&nbsp;
VGG19-SVD &nbsp;&middot;&nbsp; Macro F1 98.77%
</p>
""")
# Start the Gradio web server for the app defined above.
demo.launch()