# Leprosy-detection Gradio demo (Hugging Face Space).
import os
from typing import Dict, Tuple

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# CustomViT model definition
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A single Conv2d whose kernel size equals its stride performs both the
    patch extraction and the linear projection in one pass.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patches per side, squared = total patch count.
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, n_patches, embed_dim) tokens."""
        embedded = self.proj(x)           # (B, embed_dim, H/ps, W/ps)
        embedded = embedded.flatten(2)    # (B, embed_dim, n_patches)
        return embedded.transpose(1, 2)   # (B, n_patches, embed_dim)
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings."""

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_heads = n_heads
        # 1/sqrt(head_dim) scaling for the dot-product scores.
        self.scale = (dim // n_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention; input and output are both (B, N, dim)."""
        batch, seq_len, dim = x.shape
        head_dim = dim // self.n_heads
        # Project to q/k/v and split heads: (3, B, n_heads, N, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(batch, seq_len, 3, self.n_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        queries, keys, values = qkv.unbind(0)
        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        # Merge heads back into the embedding dimension.
        context = (weights @ values).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_drop(self.proj(context))
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: attention then MLP, each residual."""

    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        # Two-layer feed-forward with GELU, widened by mlp_ratio.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """Residual attention followed by residual MLP; shape preserved."""
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
class CustomViT(nn.Module):
    """Vision Transformer classifier assembled from the local building blocks.

    Patch-embeds the image, prepends a learnable [CLS] token, adds learned
    position embeddings, runs `depth` transformer blocks, and classifies
    from the final [CLS] representation.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=2,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4.,
                 qkv_bias=True, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # Learnable classification token and position embeddings (zero-initialized).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, drop_rate)
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (B, num_classes)."""
        batch = x.shape[0]
        tokens = self.patch_embed(x)
        # Prepend the [CLS] token, then add position information.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        # Classify from the [CLS] position only.
        return self.head(tokens[:, 0])
# Helper functions
def load_model(model_path: str) -> CustomViT:
    """Load CustomViT weights from `model_path` and return the model in eval mode.

    Handles checkpoints saved through nn.DataParallel by stripping the
    'module.' prefix when every key carries it. The model is placed on GPU
    when CUDA is available, otherwise CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomViT(num_classes=2)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoint files.
    state_dict = torch.load(model_path, map_location=device)
    # DataParallel checkpoints prefix every key with 'module.'.
    if all(key.startswith('module.') for key in state_dict):
        state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Convert an input image into a normalized (1, 3, 224, 224) tensor.

    Accepts either a numpy array (as Gradio supplies) or a PIL Image.
    The image is forced to 3-channel RGB so grayscale or RGBA uploads do
    not break the 3-channel Normalize transform below.
    """
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Fix: guard against grayscale/RGBA inputs — Normalize expects exactly
    # 3 channels; convert() is a no-op for images already in RGB mode.
    image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet statistics, matching standard ViT preprocessing.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)
def predict_image(image: np.ndarray, model: CustomViT) -> Tuple[np.ndarray, Dict[str, float]]:
    """Classify `image` with `model` and return (annotated image, label dict).

    Assumes class index 0 = "Leprosy" and index 1 = "No Leprosy" — TODO
    confirm against the training label order.

    Fix: Gradio's gr.Image() delivers an RGB numpy array, but the original
    code drew BGR-coded colors and then applied cv2.COLOR_BGR2RGB to the
    already-RGB frame, swapping the photo's red/blue channels. The frame now
    stays RGB throughout and colors are specified as RGB tuples.
    """
    device = next(model.parameters()).device
    # Preprocess and run a single forward pass without tracking gradients.
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    # Annotate a copy so the caller's array is untouched.
    visualization = image.copy()
    result = "Leprosy" if probabilities[0] > probabilities[1] else "No Leprosy"
    confidence = float(probabilities[0] if result == "Leprosy" else probabilities[1])
    # RGB colors: red for a positive finding, green otherwise.
    color = (255, 0, 0) if result == "Leprosy" else (0, 255, 0)
    cv2.putText(visualization, f"{result}: {confidence:.2%}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
    # Per-class probabilities for the Gradio Label component.
    labels = {
        "Leprosy": float(probabilities[0]),
        "No Leprosy": float(probabilities[1])
    }
    return visualization, labels
# Download example images
# Dropbox-hosted sample inputs; `dl=1` forces a direct file download
# instead of the Dropbox preview page.
file_urls = [
    'https://www.dropbox.com/scl/fi/onrg1u9tqegh64nsfmxgr/lp2.jpg?rlkey=2vgw5n6abqmyismg16mdd1v3n&dl=1',
    'https://www.dropbox.com/scl/fi/xq103ic7ovuuei3l9e8jf/lp1.jpg?rlkey=g7d9khyyc6wplv0ljd4mcha60&dl=1',
    'https://www.dropbox.com/scl/fi/fagkh3gnio2pefdje7fb9/Non_Leprosy_210823_86_jpg.rf.5bb80a7704ecc6c8615574cad5d074c5.jpg?rlkey=ks8afue5gsx5jqvxj3u9mbjmg&dl=1',
]
def download_example_images():
    """Download the example images once and return their local filenames.

    Files already on disk are reused. Fixes: the original request had no
    timeout (could hang app startup indefinitely) and no status check (an
    HTTP error page would be silently written to disk as a .jpg).
    """
    examples = []
    for i, url in enumerate(file_urls):
        filename = f"example_{i}.jpg"
        if not os.path.exists(filename):
            # Bounded wait so a stalled download cannot hang startup.
            response = requests.get(url, timeout=30)
            # Fail loudly instead of saving an error page as an image.
            response.raise_for_status()
            with open(filename, 'wb') as f:
                f.write(response.content)
        examples.append(filename)
    return examples
# Main Gradio interface
def create_gradio_interface():
    """Build and return the Gradio demo wired to the trained ViT model."""
    # Load the trained weights once, at construction time.
    model = load_model('best_custom_vit_mo50.pth')

    def inference(image):
        # Close over the loaded model so Gradio only passes the image.
        return predict_image(image, model)

    # Fetch (or reuse) the example images, then assemble the UI.
    example_files = download_example_images()
    demo = gr.Interface(
        fn=inference,
        inputs=gr.Image(),
        outputs=[
            gr.Image(label="Prediction Visualization"),
            gr.Label(label="Classification Probabilities"),
        ],
        title="Leprosy Detection using Vision Transformer",
        description="Upload an image to detect signs of leprosy using a Vision Transformer model.",
        examples=example_files,
        cache_examples=False,
    )
    return demo
if __name__ == "__main__":
    # Build the demo and serve it with Gradio's default host/port settings.
    demo = create_gradio_interface()
    demo.launch()