"""Gradio demo: 8-class blood cell classification with a low-rank VGG19.

The script rebuilds a VGG19 whose first two classifier layers are replaced by
factorised ``LowRankLinear`` layers, loads the weights from
``vgg19_svd_final.pth``, and serves a small Gradio UI around ``predict``.
"""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Order matters: index i of the model output is reported as CLASS_NAMES[i].
CLASS_NAMES = [
    "Basophil",
    "Eosinophil",
    "Erythroblast",
    "Immature Granulocyte",
    "Lymphocyte",
    "Monocyte",
    "Neutrophil",
    "Platelet",
]

# One-line description shown under the top prediction.
CLASS_INFO = {
    "Basophil": "Rare granulocyte involved in allergic responses and parasitic defense.",
    "Eosinophil": "Granulocyte elevated in allergic conditions and helminth infections.",
    "Erythroblast": "Nucleated erythrocyte precursor; abnormal in peripheral blood.",
    "Immature Granulocyte": "Premature granulocyte — left shift; indicative of active infection or marrow stress.",
    "Lymphocyte": "Core adaptive immunity cell; T/B lineage distinction requires further analysis.",
    "Monocyte": "Large agranulocyte; differentiates into macrophages and dendritic cells.",
    "Neutrophil": "Primary innate responder to bacterial infection; most abundant leukocyte.",
    "Platelet": "Anucleate thrombocyte fragment essential for haemostasis.",
}

# Message shown both as the initial summary and when no image is supplied.
IDLE_MESSAGE = "Upload or capture an image to begin."

# Resize to 224x224, convert to a tensor, then normalise per channel.
# The mean/std values match the commonly used ImageNet channel statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ---------------------------------------------------------------------------
# Model definition — must mirror training code exactly
# ---------------------------------------------------------------------------
class LowRankLinear(nn.Module):
    """Linear layer factorised into two smaller linear maps.

    The input is first projected to ``rank`` features without a bias
    (``fc_in``) and then expanded to ``out_features`` with a bias
    (``fc_out``). The attribute names are part of the checkpoint's
    state-dict keys and must not be renamed.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, bias=None):
        """Create the two factors.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            rank: Width of the intermediate projection.
            bias: Optional tensor copied into ``fc_out.bias``; when ``None``
                the default ``nn.Linear`` bias initialisation is kept.
        """
        super().__init__()
        self.rank = rank
        self.fc_in = nn.Linear(in_features, rank, bias=False)
        self.fc_out = nn.Linear(rank, out_features, bias=True)
        if bias is not None:
            # Copy without recording the operation in autograd.
            with torch.no_grad():
                self.fc_out.bias.copy_(bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc_in`` followed by ``fc_out``."""
        return self.fc_out(self.fc_in(x))


def build_model(num_classes: int = 8, rank_fc0: int = 256, rank_fc3: int = 256) -> nn.Module:
    """Build an uninitialised VGG19 with a low-rank classifier head.

    Args:
        num_classes: Number of output logits of the final linear layer.
        rank_fc0: Rank of the factorised first classifier layer (25088 -> 4096).
        rank_fc3: Rank of the factorised second classifier layer (4096 -> 4096).

    Returns:
        The VGG19 module with its ``classifier`` replaced. No pretrained
        weights are loaded here (``weights=None``).
    """
    base = models.vgg19(weights=None)
    base.classifier = nn.Sequential(
        LowRankLinear(25088, 4096, rank_fc0),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        LowRankLinear(4096, 4096, rank_fc3),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, num_classes),
    )
    return base


# ---------------------------------------------------------------------------
# Load checkpoint
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, which can execute code from the checkpoint file. Only use this with
# a trusted checkpoint; consider weights_only=True if the file holds nothing
# but tensors and plain Python values — confirm against the training code.
ckpt = torch.load("vgg19_svd_final.pth", map_location=device, weights_only=False)
rank_fc0 = ckpt["rank_fc0"]
rank_fc3 = ckpt["rank_fc3"]

# Derive the head size from CLASS_NAMES so the two can never drift apart.
model = build_model(num_classes=len(CLASS_NAMES), rank_fc0=rank_fc0, rank_fc3=rank_fc3)
model.load_state_dict(ckpt["model_state_dict"])
model.to(device).eval()


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: np.ndarray):
    """Classify one image and describe the top prediction.

    Args:
        image: Image as a NumPy array as delivered by the Gradio image
            component, or ``None`` when the input has been cleared.

    Returns:
        A ``(label_dict, summary)`` pair. ``label_dict`` maps every class
        name to its softmax probability and ``summary`` is a Markdown string
        with the top class, its confidence in percent and its description.
        When ``image`` is ``None`` the pair is ``(None, IDLE_MESSAGE)``.
    """
    if image is None:
        return None, IDLE_MESSAGE

    # Let Pillow infer the mode from the array shape and then normalise to
    # RGB, so grayscale or RGBA arrays are handled as well; this also avoids
    # the explicit ``mode`` argument of ``Image.fromarray``.
    img = Image.fromarray(image.astype("uint8")).convert("RGB")
    tensor = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
        # Drop only the batch axis so the result is always one value per class.
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    top_idx = int(np.argmax(probs))
    top_name = CLASS_NAMES[top_idx]
    top_conf = float(probs[top_idx]) * 100.0

    label_dict = {name: float(prob) for name, prob in zip(CLASS_NAMES, probs)}
    summary = f"**{top_name}** — {top_conf:.1f}% confidence\n\n{CLASS_INFO[top_name]}"
    return label_dict, summary


# ---------------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------------
with gr.Blocks(
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.slate,
        neutral_hue=gr.themes.colors.zinc,
        font=gr.themes.GoogleFont("Inter"),
    ),
    css=(
        " .gradio-container { max-width: 900px; margin: 0 auto; }"
        " #title { border-bottom: 1px solid #e5e7eb; padding-bottom: 1.25rem; margin-bottom: 1.5rem; }"
        " #title h1 { font-size: 1.4rem; font-weight: 600; color: #111; letter-spacing: -0.02em; }"
        " #title p { font-size: 0.875rem; color: #6b7280; margin-top: 0.25rem; }"
        " "
    ),
    title="Blood Cell Classifier",
) as demo:
    with gr.Column(elem_id="title"):
        # NOTE(review): the CSS above styles `#title h1` and `#title p`, but
        # this snippet contains no such tags — confirm the header markup was
        # not lost.
        gr.HTML("\nVGG19 + SVD low-rank compression | BloodMNIST | 8-class\n")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "webcam", "clipboard"],
                height=280,
            )
            run_btn = gr.Button("Run inference", variant="primary")
        with gr.Column(scale=1):
            label_output = gr.Label(label="Class probabilities", num_top_classes=8)
            summary_output = gr.Markdown(value=IDLE_MESSAGE)

    # Run inference both on an explicit button press and whenever the image
    # changes (including when it is cleared, which yields the idle message).
    run_btn.click(fn=predict, inputs=image_input, outputs=[label_output, summary_output])
    image_input.change(fn=predict, inputs=image_input, outputs=[label_output, summary_output])

    gr.HTML("Applied Statistics in AI · Cairo University 2026 · VGG19-SVD · Macro F1 98.77%\n")

demo.launch()