import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    The attribute names ``ca``, ``sa`` and ``ca_sigmoid`` appear in checkpoint
    state-dict keys and must not be renamed.

    Args:
        channels: Number of input (and output) feature channels.
        reduction: Bottleneck ratio of the channel-attention MLP.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Avoid a zero-width bottleneck when channels < reduction; for
        # channels >= reduction this is exactly channels // reduction.
        hidden = max(1, channels // reduction)
        # Channel attention uses only an average-pooled descriptor (the CBAM
        # paper additionally uses a max-pooled branch). Kept as-is so trained
        # weights reproduce the same outputs.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        # Spatial attention computed from the [channel-mean, channel-max] maps.
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )
        self.ca_sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` reweighted by channel then spatial attention (same shape).

        Assumes ``x`` is (N, C, H, W) — TODO confirm for unbatched callers.
        """
        x = x * self.ca_sigmoid(self.ca(x))
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = x * self.sa(torch.cat([avg_out, max_out], dim=1))
        return x


class ArcMarginProduct(nn.Module):
    """Scaled-cosine classifier head (the inference half of an ArcFace head).

    ``forward`` returns ``s * cos(theta)`` between L2-normalised input features
    and L2-normalised class weights, so every logit lies in ``[-s, s]``.

    Args:
        in_features: Size of each input feature vector.
        out_features: Number of classes.
        s: Logit scale.
        m: Angular margin. NOTE: stored for reference but never applied —
            ``forward`` takes no labels, and an additive angular margin needs
            the target class, so it cannot be used here.
    """

    def __init__(self, in_features, out_features, s=35.0, m=0.50):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m

    def forward(self, input):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        return cosine * self.s


class TongueArcResNet(nn.Module):
    """ResNet-18 trunk + CBAM attention + scaled-cosine classifier head.

    NOTE: ``backbone`` stays registered even though ``forward`` only uses
    ``features``. Both share the same sub-modules, so the state dict contains
    both ``backbone.*`` and ``features.*`` keys. Checkpoints saved from this
    class therefore need both attributes to load with strict key matching.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        # Drop ResNet's final avgpool and fc to keep the 512-channel feature map.
        self.features = nn.Sequential(*list(self.backbone.children())[:-2])
        self.attention = CBAM(512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.arcface = ArcMarginProduct(512, num_classes, s=35)

    def forward(self, x):
        """Map an image batch (assumed (N, 3, H, W)) to (N, num_classes) logits."""
        x = self.features(x)
        x = self.attention(x)
        features = self.avgpool(x).flatten(1)
        return self.arcface(features)


def _register_numpy_scalar_safe_global():
    """Allow-list NumPy's scalar reconstructor for ``torch.load``'s safe unpickler.

    Works with NumPy 1.x (``numpy.core``) and 2.x (``numpy._core``). It is a
    no-op on PyTorch builds that lack ``torch.serialization.add_safe_globals``.
    """
    add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
    if add_safe_globals is None:
        return
    try:
        from numpy._core.multiarray import scalar  # NumPy >= 2.0
    except ImportError:
        from numpy.core.multiarray import scalar  # NumPy 1.x
    add_safe_globals([scalar])


class TongueModelWrapper:
    """Load a trained ``TongueArcResNet`` checkpoint and run single-image inference.

    Args:
        model_path: Path to a ``torch.save`` checkpoint. It is either a raw
            state dict, or a dict with a ``'model_state'`` entry (and optionally
            ``'auc'``).
        num_classes: Number of model outputs. ``predict`` requires this to
            equal ``len(CLASS_LABELS)``.
        device: Device to run on. Defaults to ``"cpu"``, the previously fixed value.
    """

    # Output labels in class-index order: healthy (NHC), general dry eye (DES),
    # sicca / Sjögren's (SJS), per the label text.
    CLASS_LABELS = ("NHC(健康人)", "DES (一般乾眼)", "SJS (乾燥症)")

    def __init__(self, model_path, num_classes=3, device="cpu"):
        self.device = torch.device(device)
        self.model = TongueArcResNet(num_classes=num_classes)
        _register_numpy_scalar_safe_global()
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code embedded in the file — only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state'])
            auc_text = self._format_metric(checkpoint.get('auc', 0))
            print(f"成功載入權重!模型訓練指標:AUC={auc_text}")
        else:
            self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()
        # PIL image -> float tensor in [0, 1], then normalise with the ImageNet
        # channel means/stds.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _format_metric(value):
        """Format a numeric metric with 4 decimals; fall back to ``str`` otherwise."""
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError, RuntimeError):
            return str(value)

    @staticmethod
    def _to_gray(img_array):
        """Reduce an HxW, HxWx1, HxWx3 (RGB) or HxWx4 (RGBA) array to one gray channel.

        Raises:
            ValueError: for any other array shape.
        """
        img = np.asarray(img_array)
        if img.ndim == 2:
            return np.ascontiguousarray(img)
        if img.ndim == 3:
            channels = img.shape[2]
            if channels == 1:
                return np.ascontiguousarray(img[:, :, 0])
            if channels == 3:
                return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            if channels == 4:
                return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        raise ValueError(
            f"Unsupported image shape {img.shape}; expected HxW, HxWx1, HxWx3 or HxWx4."
        )

    def preprocess(self, img_array):
        """Build the model's 3-channel 512x512 input image from an image array.

        The output channels are:
        - the resized grayscale image;
        - its CLAHE-equalised version;
        - the min-max-normalised absolute Laplacian (an edge map).

        Assumes uint8 input — CLAHE and ``Image.fromarray`` of the stacked
        result expect 8-bit data.
        """
        img_gray = self._to_gray(img_array)
        img_gray = cv2.resize(img_gray, (512, 512))
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        ch_clahe = clahe.apply(img_gray)
        # Compute in float64 so negative Laplacian responses survive before abs().
        ch_lap = np.absolute(cv2.Laplacian(img_gray, cv2.CV_64F, ksize=3))
        ch_lap = cv2.normalize(ch_lap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        combined = np.stack([img_gray, ch_clahe, ch_lap], axis=-1)
        return Image.fromarray(combined)

    def predict(self, img_array):
        """Return class probabilities for one image, or ``None`` if no image is given.

        Returns:
            dict mapping each label in ``CLASS_LABELS`` to its softmax probability.

        Raises:
            ValueError: if the model's output count differs from ``len(CLASS_LABELS)``.
        """
        if img_array is None:
            return None
        processed_img = self.preprocess(img_array)
        input_tensor = self.transform(processed_img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        if len(probs) != len(self.CLASS_LABELS):
            raise ValueError(
                f"Model produced {len(probs)} class scores but "
                f"{len(self.CLASS_LABELS)} labels are defined."
            )
        return {label: float(p) for label, p in zip(self.CLASS_LABELS, probs)}