# NOTE(review): removed non-Python scrape residue that preceded the module
# (a "File size" banner, stray git commit hashes, and a duplicated
# line-number gutter 1-62). None of it was source code.
import gradio as gr
from torch import nn
import torch
# import os
# import tempfile
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification
import numpy as np

# Preprocessing pipeline applied to every input image:
#   1. resize so the longest side is 384 px (aspect ratio preserved),
#   2. CLAHE local-contrast enhancement,
#   3. per-image normalization,
#   4. zero-pad (border_mode=0 = constant) to a square 384x384 canvas,
#   5. convert the numpy array to a torch tensor.
_pipeline_steps = [
    A.LongestMaxSize(384),
    A.CLAHE(),
    A.Normalize(normalization='image'),
    A.PadIfNeeded(384, 384, border_mode=0, value=(0)),
    ToTensorV2(),
]
transforms = A.Compose(_pipeline_steps)

# model
class PrHu_model(nn.Module):
    """ConvNeXt-V2 image classifier wrapped to emit raw logits.

    Configured for single-channel (grayscale) 384x384 inputs and a single
    output label, i.e. one logit for binary fracture classification.
    """

    def __init__(self):
        super().__init__()
        cfg = ConvNextV2Config(
            num_channels=1,
            drop_path_rate=0,
            image_size=384,
            num_labels=1,
            depths=[2, 2, 6, 2],
            hidden_sizes=[16, 32, 64, 128],
        )
        # Keep the config around for introspection; the HF model owns a copy.
        self.configuration = cfg
        self.model = ConvNextV2ForImageClassification(cfg)

    def forward(self, x):
        # Unwrap the HuggingFace output object and return only the logits.
        return self.model(x).logits
    
# Select GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the DataParallel wrapper is presumably kept so the checkpoint's
# 'module.'-prefixed state-dict keys line up — confirm how the weights were saved.
model = nn.DataParallel(PrHu_model()).to(device)
# weights_only=True restricts torch.load to tensor data (no arbitrary pickles).
model.load_state_dict(torch.load('PrHu_model.pth', map_location=device, weights_only=True))
model.eval()

def inference(img_dir):
    """Classify an uploaded radiograph as fractured or not.

    Args:
        img_dir: Filesystem path to the input image (Gradio supplies a
            filepath string because the Image input uses type="filepath").

    Returns:
        "Fracture +" if the model's single logit is positive,
        otherwise "Fracture -".
    """
    # Load as 8-bit grayscale ('L'); the model expects one input channel.
    image = np.array(Image.open(str(img_dir)).convert('L'))
    # Apply the shared preprocessing pipeline (resize/CLAHE/normalize/pad/tensor).
    image = transforms(image=image)['image']
    image = image.float().to(device)
    with torch.inference_mode():
        # Add a batch dimension; .item() extracts the single scalar logit.
        logit = model(image.unsqueeze(0)).item()
    # Positive logit == sigmoid probability > 0.5 == fracture predicted.
    return "Fracture +" if logit > 0 else "Fracture -"

# Bundled example inputs (file names suggest NF* = no fracture, F* = fracture
# -- verify against the actual images).
examples = ["NF1.jpg", "NF2.jpg", "NF3.jpg", "F1.jpg", "F2.jpg", "F3.jpg"]

# UI: one image in (as a file path), one text label out.
iface = gr.Interface(
    fn=inference,
    inputs=[gr.Image(label="Upload or Select Input Image", type="filepath")],
    outputs=[gr.Textbox(label="Classification Result")],
    title="Proximal Humerus Fracture Detection",
    description="Upload an image, and get the classification result.",
    examples=examples,
    cache_examples=True,
)
iface.launch()