Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import models, transforms | |
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    Channel attention here uses only the average-pooled descriptor (there is
    no max-pool branch): global average pool -> 1x1 conv bottleneck -> ReLU ->
    1x1 conv -> sigmoid. Spatial attention concatenates the per-position
    channel mean and channel max, then applies a 7x7 conv + sigmoid.
    The output has the same shape as the input.

    Args:
        channels: number of input channels.
        reduction: bottleneck ratio of the channel-attention MLP
            (hidden width is ``channels // reduction``).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Attribute names and Sequential indices determine the state_dict keys
        # (ca.1.weight, ca.3.weight, sa.0.weight); keep them stable so saved
        # checkpoints still load.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )
        self.ca_sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel gate: one weight per channel, broadcast over spatial dims.
        channel_gate = self.ca_sigmoid(self.ca(x))
        x = x * channel_gate
        # Spatial gate from the channel-wise mean and max maps.
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        spatial_gate = self.sa(torch.cat((mean_map, max_map), dim=1))
        return x * spatial_gate
class ArcMarginProduct(nn.Module):
    """Cosine-similarity classification head in the style of ArcFace.

    ``forward(input)`` returns ``s * cos(theta)``, where ``theta`` is the angle
    between each L2-normalised input row and each L2-normalised class-weight
    row. If integer class labels are also passed, the additive angular margin
    ``m`` is applied to the target-class entry only, giving
    ``s * cos(theta + m)`` there (the ArcFace training objective). The
    no-label path used for inference is unchanged.

    Args:
        in_features: length of each input feature vector.
        out_features: number of classes (rows of ``weight``).
        s: scale applied to the cosine logits.
        m: additive angular margin in radians; only used when ``label`` is given.
    """

    def __init__(self, in_features, out_features, s=35.0, m=0.50):
        super().__init__()
        import math  # function-local so the module's import block is untouched

        # torch.empty(..., dtype=float32) is the non-legacy equivalent of
        # torch.FloatTensor(out_features, in_features); xavier init fills it.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        # Keep the margin; it used to be accepted and silently discarded.
        self.m = m
        # Precomputed margin constants (plain floats, not part of the state dict).
        self._cos_m = math.cos(m)
        self._sin_m = math.sin(m)
        # For theta > pi - m, cos(theta + m) stops decreasing in theta. Beyond
        # that threshold, fall back to the linear penalty cos(theta) - m*sin(m),
        # as the reference ArcFace implementation does.
        self._th = math.cos(math.pi - m)
        self._mm = math.sin(math.pi - m) * m

    def forward(self, input, label=None):
        """Return scaled cosine logits of shape (batch, out_features).

        Args:
            input: features of shape (batch, in_features).
            label: optional int64 tensor of shape (batch,) holding target class
                indices. When given, the angular margin is applied to those
                entries (training); when None, plain ``s * cos(theta)`` is returned.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine * self.s
        # Clamp guards sqrt against tiny negative values from rounding.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0.0, 1.0))
        phi = cosine * self._cos_m - sine * self._sin_m  # cos(theta + m)
        phi = torch.where(cosine > self._th, phi, cosine - self._mm)
        one_hot = F.one_hot(label, num_classes=cosine.size(1)).to(cosine.dtype)
        logits = one_hot * phi + (1.0 - one_hot) * cosine
        return logits * self.s
class TongueArcResNet(nn.Module):
    """ResNet-18 trunk, CBAM attention and a cosine (ArcFace-style) classifier head.

    ``forward`` returns the scaled cosine logits produced by the
    ``ArcMarginProduct`` head, one column per class.

    Note: ``backbone`` stays registered as a submodule even though only its
    convolutional children are used (through ``features``). Both attributes
    share the same layer objects, so the state dict contains ``backbone.*``
    keys (including the unused ``backbone.fc``) as well as ``features.*``
    keys. Checkpoints saved from this class presumably rely on that layout,
    so do not drop ``backbone`` without migrating them.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        feat_dim = 512  # channel width of resnet18's last conv stage
        self.backbone = models.resnet18(weights=None)
        # Every child except the final avgpool and fc -> spatial feature map.
        trunk_layers = list(self.backbone.children())[:-2]
        self.features = nn.Sequential(*trunk_layers)
        self.attention = CBAM(feat_dim)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.arcface = ArcMarginProduct(feat_dim, num_classes, s=35)

    def forward(self, x):
        attended = self.attention(self.features(x))
        pooled = torch.flatten(self.avgpool(attended), 1)
        return self.arcface(pooled)
class TongueModelWrapper:
    """Load a trained ``TongueArcResNet`` checkpoint and classify single images on CPU.

    SECURITY NOTE(review): the checkpoint is read with
    ``torch.load(..., weights_only=False)``. That is full pickle
    deserialisation, which can execute arbitrary code embedded in the file.
    Only pass ``model_path`` values that point to trusted checkpoints.
    Switching to ``weights_only=True`` would require allow-listing every
    non-tensor object stored in the checkpoint.
    """

    def __init__(self, model_path, num_classes=3):
        """Build the network, load its weights from ``model_path`` and set eval mode.

        Args:
            model_path: path to a ``torch.save`` checkpoint. It is either a dict
                holding the weights under ``'model_state'`` (optionally with an
                ``'auc'`` metric), or a bare state dict.
            num_classes: number of output classes of the network head.
        """
        self.device = torch.device("cpu")
        self.model = TongueArcResNet(num_classes=num_classes)
        self._register_numpy_scalar_safe_global()
        # Full pickle load -- see the class-level security note.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state'])
            # Message: "Weights loaded successfully! Model training metric: AUC=..."
            print(f"成功載入權重!模型訓練指標:AUC={checkpoint.get('auc', 0):.4f}")
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        # Standard ImageNet per-channel mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _register_numpy_scalar_safe_global():
        """Allow-list numpy's scalar reconstructor for ``weights_only=True`` loads.

        NumPy >= 2 exposes it as ``np._core.multiarray.scalar``; NumPy 1.x only
        has ``np.core.multiarray.scalar``, so both are tried. PyTorch releases
        without ``torch.serialization.add_safe_globals`` are skipped. The
        registration is process-wide. It has no effect on a
        ``weights_only=False`` load.
        """
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is None:
            return
        try:
            scalar = np._core.multiarray.scalar
        except AttributeError:
            scalar = np.core.multiarray.scalar
        add_safe_globals([scalar])

    @staticmethod
    def _to_gray(img_array):
        """Return a single-channel 2-D array for grayscale, 1-channel, RGB or RGBA input."""
        if img_array.ndim == 2:
            return img_array
        channels = img_array.shape[2]
        if channels == 1:
            # Contiguous copy: the channel slice is a strided view.
            return np.ascontiguousarray(img_array[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(img_array, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

    def preprocess(self, img_array):
        """Turn an image array into the 3-channel 512x512 texture image fed to the model.

        Channels, in order: resized grayscale, CLAHE-equalised grayscale, and
        the min-max-normalised absolute Laplacian (edge magnitude).

        Args:
            img_array: numpy image, either RGB ``(H, W, 3)``, RGBA ``(H, W, 4)``,
                ``(H, W, 1)`` or grayscale ``(H, W)``. It is expected to be uint8,
                since CLAHE and the PIL conversion below assume 8-bit data.
        Returns:
            A ``PIL.Image`` built from the stacked ``(512, 512, 3)`` uint8 array.
        """
        img_gray = self._to_gray(img_array)
        img_gray = cv2.resize(img_gray, (512, 512))
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        ch_clahe = clahe.apply(img_gray)
        # Laplacian in float64 so negative responses survive until abs().
        ch_lap = np.absolute(cv2.Laplacian(img_gray, cv2.CV_64F, ksize=3))
        ch_lap = cv2.normalize(ch_lap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        combined = np.stack([img_gray, ch_clahe, ch_lap], axis=-1)
        return Image.fromarray(combined)

    def predict(self, img_array):
        """Classify one image and return its per-class softmax probabilities.

        Args:
            img_array: image accepted by ``preprocess``, or None.
        Returns:
            None when ``img_array`` is None. Otherwise, a dict mapping each of
            the three class labels to its softmax probability as a Python float.
        """
        if img_array is None:
            return None
        processed_img = self.preprocess(img_array)
        # Add a batch dimension of 1.
        input_tensor = self.transform(processed_img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probs = torch.softmax(outputs, dim=1).numpy()[0]
        # Labels (healthy / general dry eye / sicca syndrome) are tied to class
        # indices 0-2 and assume num_classes == 3.
        # NOTE(review): confirm this order matches the training label encoding.
        return {
            "NHC(健康人)": float(probs[0]),
            "DES (一般乾眼)": float(probs[1]),
            "SJS (乾燥症)": float(probs[2]),
        }