Spaces:
Sleeping
Sleeping
| # ============================================================ | |
| # Gradio App: ResNet18 Classification + SwinUNet Segmentation | |
| # With Proper AES Encryption for Uploaded Images | |
| # ============================================================ | |
| import os | |
| import io | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| import timm | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from cryptography.fernet import Fernet | |
# ---------------- Security Setup ----------------
KEY_PATH = "secret.key"

# Persist a single Fernet key across restarts so previously encrypted
# payloads remain decryptable. Generate it only on first run, and keep the
# freshly generated key in memory instead of re-reading the file we just wrote.
if not os.path.exists(KEY_PATH):
    key = Fernet.generate_key()
    with open(KEY_PATH, "wb") as f:
        f.write(key)
    # A world-readable symmetric key defeats the encryption entirely;
    # restrict to owner. Best effort: chmod is a no-op under Windows ACLs.
    try:
        os.chmod(KEY_PATH, 0o600)
    except OSError:
        pass
else:
    with open(KEY_PATH, "rb") as f:
        key = f.read()

fernet = Fernet(key)
def encrypt_bytes(image_bytes):
    """Return the Fernet-encrypted token for the given raw image bytes."""
    token = fernet.encrypt(image_bytes)
    return token
def decrypt_bytes(encrypted_bytes):
    """Decrypt a Fernet token and return its payload opened as a PIL image."""
    raw = fernet.decrypt(encrypted_bytes)
    buffer = io.BytesIO(raw)
    return Image.open(buffer)
# ---------------- Device ----------------
# Run on GPU when one is available; every model and input tensor below is
# moved to DEVICE before inference.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------- Classification Model ----------------
class BrainTumorResNet18(nn.Module):
    """ResNet18 backbone with a dropout + linear head for 4-way MRI classification."""

    def __init__(self, num_classes=4, pretrained=False):
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
        # favor of `weights=` — confirm the torchvision version pinned here.
        backbone = models.resnet18(pretrained=pretrained)
        head_in = backbone.fc.in_features
        # Swap the ImageNet head for a dropout-regularized classifier.
        backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(head_in, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
# ---------------- Classification checkpoint & preprocessing ----------------
clf_model = BrainTumorResNet18(num_classes=4).to(DEVICE)
# NOTE(review): torch.load unpickles arbitrary objects; prefer
# weights_only=True (torch >= 1.13) if the checkpoint source is untrusted.
clf_model.load_state_dict(torch.load("models/best_resnet18_mri.pth", map_location=DEVICE))
clf_model.eval()
# 224x224 input; a single mean/std pair is broadcast over all channels,
# mapping pixel values into [-1, 1].
clf_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# Index order must match the label encoding used at training time.
CLASS_NAMES = ["glioma", "meningioma", "notumor", "pituitary"]
# ---------------- Segmentation Model ----------------
class ConvBlock(nn.Module):
    """Two stacked 3x3 Conv-BN-ReLU units; spatial size preserved (padding=1)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        channels = in_ch
        for _ in range(2):
            layers.append(nn.Conv2d(channels, out_ch, 3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
            channels = out_ch
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class SwinUNet(nn.Module):
    """Swin-Transformer encoder with a UNet-style transposed-conv decoder.

    Returns raw logits of shape (B, num_classes, H', W'); no sigmoid or
    softmax is applied inside the model — callers must activate the output.
    """

    def __init__(self, encoder_name="swin_small_patch4_window7_224", pretrained=True, num_classes=1):
        super().__init__()
        # timm feature extractor: 4 pyramid stages, shallowest first.
        self.encoder = timm.create_model(encoder_name, pretrained=pretrained,
                                         features_only=True, out_indices=(0,1,2,3))
        # Per-stage channel counts as reported by timm's feature metadata.
        enc_chs = self.encoder.feature_info.channels()
        # Decoder: 2x transposed-conv upsample, concat matching skip, fuse.
        self.up3 = nn.ConvTranspose2d(enc_chs[3], enc_chs[2], 2, stride=2)
        self.dec3 = ConvBlock(enc_chs[2]*2, enc_chs[2])
        self.up2 = nn.ConvTranspose2d(enc_chs[2], enc_chs[1], 2, stride=2)
        self.dec2 = ConvBlock(enc_chs[1]*2, enc_chs[1])
        self.up1 = nn.ConvTranspose2d(enc_chs[1], enc_chs[0], 2, stride=2)
        self.dec1 = ConvBlock(enc_chs[0]*2, enc_chs[0])
        # Final 2x upsample, then a small conv head down to num_classes.
        self.final_up = nn.ConvTranspose2d(enc_chs[0], 64, 2, stride=2)
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, 1)
        )

    def _ensure_nchw(self, feat, expected_ch):
        # Some timm Swin variants emit NHWC feature maps; normalize to NCHW
        # by checking which axis carries the expected channel count.
        if feat.ndim==4:
            if feat.shape[1]==expected_ch: return feat
            if feat.shape[-1]==expected_ch: return feat.permute(0,3,1,2).contiguous()
        return feat

    def forward(self, x):
        feats = self.encoder(x)
        expected = self.encoder.feature_info.channels()
        for i in range(len(feats)):
            feats[i] = self._ensure_nchw(feats[i], expected[i])
        f0,f1,f2,f3 = feats
        # Each decoder step: upsample, bilinear-align if sizes differ by a
        # rounding pixel, concat the skip connection, then fuse with ConvBlock.
        d3 = self.up3(f3)
        if d3.shape[-2:] != f2.shape[-2:]:
            d3 = nn.functional.interpolate(d3, size=f2.shape[-2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3,f2], dim=1))
        d2 = self.up2(d3)
        if d2.shape[-2:] != f1.shape[-2:]:
            d2 = nn.functional.interpolate(d2, size=f1.shape[-2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2,f1], dim=1))
        d1 = self.up1(d2)
        if d1.shape[-2:] != f0.shape[-2:]:
            d1 = nn.functional.interpolate(d1, size=f0.shape[-2:], mode='bilinear', align_corners=False)
        d1 = self.dec1(torch.cat([d1,f0], dim=1))
        out = self.final_up(d1)
        return self.final_conv(out)
# ---------------- Segmentation checkpoint & preprocessing ----------------
seg_model = SwinUNet().to(DEVICE)
# NOTE(review): strict=False silently ignores missing/unexpected keys in the
# checkpoint — confirm the saved state dict actually matches this architecture.
seg_model.load_state_dict(torch.load("models/swinunet_best (6).pth", map_location=DEVICE), strict=False)
seg_model.eval()
# Segmentation input: 224x224, values in [0, 1] (no normalization).
seg_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
# ---------------- Inference Function ----------------
def predict(img):
    """Classify tumor type and overlay the segmented tumor region in red.

    Parameters
    ----------
    img : numpy.ndarray or None
        Image from the Gradio input, or None when nothing was uploaded.

    Returns
    -------
    tuple[numpy.ndarray, str]
        (224x224 RGB overlay with the predicted mask blended in red,
         human-readable prediction string with confidence).

    Raises
    ------
    gr.Error
        If no image was provided.
    """
    # Guard: Gradio passes None when the user submits without an image;
    # Image.fromarray(None) would otherwise crash with an opaque TypeError.
    if img is None:
        raise gr.Error("Please upload an MRI image first.")

    pil_img = Image.fromarray(img).convert("RGB")

    # ---- Encrypt-then-decrypt round trip: the upload only exists in
    # encrypted form outside this function's local scope. ----
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    encrypted_bytes = encrypt_bytes(buf.getvalue())
    decrypted_img = decrypt_bytes(encrypted_bytes).convert("RGB")

    # ---- Classification ----
    x = clf_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = clf_model(x)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    pred_class = CLASS_NAMES[int(np.argmax(probs))]
    conf = float(np.max(probs))

    # ---- Segmentation ----
    seg_in = seg_transform(decrypted_img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_logits = seg_model(seg_in)[0, 0]
        # BUGFIX: SwinUNet's head is a plain conv, so the output is raw
        # logits. Apply sigmoid before the 0.5 probability threshold;
        # previously logits were thresholded at 0.5 directly, which shifts
        # the effective decision boundary.
        mask = (torch.sigmoid(seg_logits) > 0.5).cpu().numpy().astype(np.uint8)

    # ---- Overlay the binary mask in red on the resized input ----
    img_np = np.array(decrypted_img.resize((224, 224)))
    mask_resized = cv2.resize(mask, (img_np.shape[1], img_np.shape[0]),
                              interpolation=cv2.INTER_NEAREST)
    overlay = img_np.copy()
    overlay[mask_resized > 0] = [255, 0, 0]  # red
    blended = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)

    return blended, f"Prediction: {pred_class} (conf: {conf:.2f})"
# ---------------- Gradio UI ----------------
# Example gallery: images/img1.jpg ... images/img10.jpg.
example_images = [f"images/img{i}.jpg" for i in range(1, 11)]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.Image(type="numpy"), gr.Textbox()],
    title="ONCOSCAN - (Brain Tumor Classification + Segmentation) ",
    description="Upload an MRI or click on one of the example images. The app will classify tumor type (ResNet18) and segment tumor region (SwinUNet).",
    examples=example_images,
    cache_examples=False,
)

demo.launch(debug=True)