File size: 3,847 Bytes
94c7c73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed6b96b
94c7c73
 
ed6b96b
f59fb67
 
 
 
 
 
 
 
94c7c73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms

class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    The channel branch squeezes the map with global average pooling (no
    max-pool path), runs a 1x1-conv bottleneck MLP and gates the channels with
    a sigmoid.  The spatial branch concatenates the per-pixel channel mean and
    channel max, applies a 7x7 conv + sigmoid and gates every spatial location.

    Args:
        channels: Number of input (and output) channels.
        reduction: Bottleneck ratio of the channel MLP (hidden = channels // reduction).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Submodule names/indices are part of the state_dict keys
        # (ca.1.weight, ca.3.weight, sa.0.weight) -- keep them stable.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )
        self.ca_sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` re-weighted by channel then spatial attention (same shape).

        ``x`` is a conv feature map laid out as (N, C, H, W); dim 1 is the
        channel axis reduced by the spatial branch.
        """
        refined = x * self.ca_sigmoid(self.ca(x))
        mean_map = refined.mean(dim=1, keepdim=True)
        max_map, _ = refined.max(dim=1, keepdim=True)
        spatial_gate = self.sa(torch.cat((mean_map, max_map), dim=1))
        return refined * spatial_gate

class ArcMarginProduct(nn.Module):
    """Cosine-similarity classification head with an optional ArcFace margin.

    Both the input embeddings and the class weight vectors are L2-normalized,
    so the raw logits are cosines scaled by ``s``.  When ``label`` is passed to
    ``forward``, the additive angular margin ``m`` is applied to each sample's
    target-class logit (cos(theta + m)).  Without ``label`` no margin is
    applied, which is the inference path.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        s: Scale multiplied onto the (possibly margin-adjusted) cosines.
        m: Additive angular margin in radians; only used when ``label`` is given.
    """

    def __init__(self, in_features, out_features, s=35.0, m=0.50):
        import math

        super(ArcMarginProduct, self).__init__()
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # values are immediately overwritten by xavier_uniform_.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m
        # Plain float attributes (not buffers) so the state_dict stays {"weight"}
        # and existing checkpoints keep loading strictly.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # For theta > pi - m, cos(theta + m) stops decreasing with theta; past
        # that threshold fall back to a linear penalty cos(theta) - mm.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label=None):
        """Return scaled cosine logits of shape (batch, out_features).

        Args:
            input: Embeddings of shape (batch, in_features).
            label: Optional integer class indices of shape (batch,).  When
                given, the angular margin is added to the target-class angle.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine * self.s
        # Small floor keeps sqrt differentiable when |cos| == 1.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=1e-7, max=1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = F.one_hot(label.long(), num_classes=cosine.size(1)).to(cosine.dtype)
        logits = one_hot * phi + (1.0 - one_hot) * cosine
        return logits * self.s

class TongueArcResNet(nn.Module):
    """ResNet-18 trunk + CBAM attention + cosine (ArcMarginProduct) head.

    Pipeline: ResNet-18 convolutional stages (its avgpool and fc are dropped)
    -> CBAM over the 512-channel map -> global average pool -> 512-d embedding
    -> scaled cosine logits (s=35) over ``num_classes`` classes.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # NOTE: the full resnet18 stays registered as `backbone` even though
        # only its layers (re-registered under `features`) are used in forward.
        # Both `backbone.*` and `features.*` keys therefore appear in the
        # state_dict; removing this attribute would break strict loading of
        # checkpoints saved from this class.
        self.backbone = models.resnet18(weights=None)
        trunk_layers = list(self.backbone.children())[:-2]  # drop avgpool + fc
        self.features = nn.Sequential(*trunk_layers)
        self.attention = CBAM(512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.arcface = ArcMarginProduct(512, num_classes, s=35)

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to cosine logits (N, num_classes)."""
        attended = self.attention(self.features(x))
        embedding = torch.flatten(self.avgpool(attended), 1)
        return self.arcface(embedding)

class TongueModelWrapper:
    """Load a TongueArcResNet checkpoint and run single-image CPU inference.

    Args:
        model_path: Path to a file written by ``torch.save`` holding either a
            dict with a ``'model_state'`` state_dict (and optionally an
            ``'auc'`` metric) or a bare state_dict.
        num_classes: Output classes of the network.  NOTE(review): ``predict``
            always reports exactly three labelled classes.
    """

    def __init__(self, model_path, num_classes=3):
        self.device = torch.device("cpu")
        self.model = TongueArcResNet(num_classes=num_classes)
        # numpy >= 2.0 moved numpy.core to numpy._core; on numpy 1.x the
        # original `np._core` access raised AttributeError.
        try:
            numpy_scalar = np._core.multiarray.scalar
        except AttributeError:
            numpy_scalar = np.core.multiarray.scalar
        # Allow-listing only matters for weights_only=True loads; guarded
        # because older torch releases lack add_safe_globals.
        if hasattr(torch.serialization, "add_safe_globals"):
            torch.serialization.add_safe_globals([numpy_scalar])
        # SECURITY: weights_only=False performs full unpickling, which can run
        # code embedded in the file.  Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state'])
            print(f"成功載入權重!模型訓練指標:AUC={checkpoint.get('auc', 0):.4f}")
        else:
            self.model.load_state_dict(checkpoint)

        self.model.eval()
        # ImageNet mean/std normalization applied to the 3-channel engineered image.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _to_gray(img_array):
        """Return a single-channel view of ``img_array`` (2-D, 1-, 3- or 4-channel)."""
        if img_array.ndim == 2:
            return img_array
        channels = img_array.shape[2]
        if channels == 1:
            return np.ascontiguousarray(img_array[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
        # 3-channel input: same conversion as before (assumes RGB order).
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

    def preprocess(self, img_array):
        """Build the 512x512 three-channel PIL image the model was trained on.

        Channels are: grayscale, CLAHE-equalized grayscale, and the min-max
        normalized absolute 3x3 Laplacian (edge strength) of the grayscale.

        Args:
            img_array: HxW grayscale, or HxWx{1,3,4} RGB/RGBA numpy array.
                Presumably uint8 (CLAHE requires 8/16-bit input) -- TODO confirm
                against the caller.
        """
        img_gray = cv2.resize(self._to_gray(img_array), (512, 512))

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        ch_clahe = clahe.apply(img_gray)

        ch_lap = np.absolute(cv2.Laplacian(img_gray, cv2.CV_64F, ksize=3))
        ch_lap = cv2.normalize(ch_lap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        combined = np.stack([img_gray, ch_clahe, ch_lap], axis=-1)
        return Image.fromarray(combined)

    def predict(self, img_array):
        """Return class probabilities for one image, or None if no image is given.

        Returns:
            Dict mapping the three class labels to softmax probabilities
            (floats summing to 1), or None when ``img_array`` is None.
        """
        if img_array is None:
            return None

        processed_img = self.preprocess(img_array)
        input_tensor = self.transform(processed_img).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            # .cpu() keeps .numpy() valid if self.device is ever changed from CPU.
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]

        return {"NHC(健康人)":float(probs[0]), "DES (一般乾眼)": float(probs[1]), "SJS (乾燥症)": float(probs[2])}