# AmirV97's picture
# display fixed
# 761d43f
import gradio as gr
from torch import nn
import torch
# import os
# import tempfile
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification
import numpy as np
# preprocessing
# Preprocessing pipeline applied to every input image before inference:
# resize so the longer side is 384 px, boost local contrast with CLAHE,
# standardize intensities, pad to a 384x384 square, and finally convert the
# numpy array into a channel-first torch tensor.
transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=384),
        A.CLAHE(),
        # normalization='image' standardizes using statistics of this image
        # itself rather than fixed dataset mean/std.
        A.Normalize(normalization='image'),
        # border_mode=0 is OpenCV's BORDER_CONSTANT; the padding value 0 is
        # applied after normalization.
        A.PadIfNeeded(min_height=384, min_width=384, border_mode=0, value=0),
        ToTensorV2(),
    ]
)
# model
class PrHu_model(nn.Module):
    """Single-logit ConvNeXt V2 classifier for 1-channel (grayscale) input.

    The backbone is instantiated from a ``ConvNextV2Config`` built here; this
    class itself does not load any weights.
    """

    def __init__(self):
        super().__init__()
        # Small ConvNeXt V2 variant: one input channel, 384 px images, a single
        # output logit, and stochastic depth (drop path) disabled.
        config = ConvNextV2Config(
            num_channels=1,
            drop_path_rate=0,
            image_size=384,
            num_labels=1,
            depths=[2, 2, 6, 2],
            hidden_sizes=[16, 32, 64, 128],
        )
        self.configuration = config
        # The attribute name `model` becomes part of the state_dict key prefix,
        # so it must stay in sync with any saved checkpoint.
        self.model = ConvNextV2ForImageClassification(config)

    def forward(self, x):
        """Return the raw classification logits (one per sample) for ``x``.

        ``x`` is presumably a (batch, 1, H, W) float tensor — TODO confirm.
        """
        outputs = self.model(x)
        return outputs.logits
# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# DataParallel prefixes parameter names with "module."; load_state_dict is
# strict by default, so the checkpoint keys must carry the same prefix.
model = nn.DataParallel(PrHu_model()).to(device)
model.load_state_dict(
    torch.load('PrHu_model.pth', map_location=device, weights_only=True)
)
# Put the network into evaluation mode before serving predictions.
model.eval()
def inference(img_dir):
    """Classify one image file as fracture-positive or fracture-negative.

    Args:
        img_dir: Path to the image file (``str`` or path-like).

    Returns:
        ``"Fracture +"`` when the model's single logit is positive,
        otherwise ``"Fracture -"``.

    Raises:
        gr.Error: if no image path was supplied (``None``).
    """
    if img_dir is None:
        # Submitting the form without an image yields None; without this check
        # the user would see a FileNotFoundError for the literal path "None".
        raise gr.Error("Please upload or select an image.")

    # Open via a context manager so the underlying file handle is released
    # once the pixels are decoded (np.array copies the data out).
    # convert('L') forces single-channel grayscale to match the 1-channel model.
    with Image.open(str(img_dir)) as img:
        image = np.array(img.convert('L'))

    image = transforms(image=image)['image']
    image = image.float().to(device)
    with torch.inference_mode():
        # unsqueeze(0) adds the batch dimension; the model emits one logit.
        logit = model(image.unsqueeze(0)).item()

    # A logit > 0 is equivalent to a sigmoid probability > 0.5.
    return "Fracture +" if logit > 0 else "Fracture -"
# Example images offered below the input widget (file names suggest
# NF = no fracture, F = fracture — TODO confirm).
examples = ["NF1.jpg", "NF2.jpg", "NF3.jpg", "F1.jpg", "F2.jpg", "F3.jpg"]

# UI
iface = gr.Interface(
    fn=inference,
    inputs=[
        # type="filepath" makes Gradio pass the uploaded image as a path string.
        gr.Image(label="Upload or Select Input Image", type="filepath"),
    ],
    outputs=[
        gr.Textbox(label="Classification Result"),  # Display the classification result
    ],
    title="Proximal Humerus Fracture Detection",
    description="Upload an image, and get the classification result.",
    examples=examples,  # Add example inputs
    # Pre-compute predictions for the examples so selecting one is instant.
    cache_examples=True,
)

# Start the web server only when run as a script, so merely importing this
# module (e.g. from tests or other tooling) does not block on launch().
if __name__ == "__main__":
    iface.launch()