Tongue_Classification / tongue_model.py
PinHsuan's picture
Update tongue_model.py
f59fb67 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate.

    The channel branch squeezes the input with global average pooling (this
    implementation has no max-pool branch) and passes it through a two-layer
    1x1-conv bottleneck; the spatial branch convolves the per-pixel channel
    mean and channel max with a 7x7 kernel.

    Args:
        channels: Number of input (and output) feature channels.
        reduction: Bottleneck reduction ratio for the channel branch.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Clamp to >= 1: when channels < reduction the floor division yields 0,
        # which would build a degenerate zero-width bottleneck layer.
        hidden = max(1, channels // reduction)
        # Submodule names and Sequential indices are part of the state_dict keys
        # (ca.1.weight, ca.3.weight, sa.0.weight); keep them stable so saved
        # checkpoints keep loading.
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sa = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )
        self.ca_sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply channel attention, then spatial attention.

        Args:
            x: Feature map, assumed ``(N, C, H, W)`` with ``C == channels``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Channel gate: (N, C, 1, 1) weights broadcast over H and W.
        x = x * self.ca_sigmoid(self.ca(x))
        # Spatial gate: stack channel-mean and channel-max maps -> (N, 2, H, W).
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = x * self.sa(torch.cat([avg_out, max_out], dim=1))
        return x
class ArcMarginProduct(nn.Module):
    """Cosine-similarity classification head (margin-free form of an ArcFace head).

    ``forward`` returns ``s * cos(theta)`` between the L2-normalised input
    features and the L2-normalised per-class weight vectors. It takes no
    labels, so the additive angular margin ``m`` is never applied here; ``m``
    is stored only so the hyper-parameter is recorded on the module.

    Args:
        in_features: Size of each input feature vector.
        out_features: Number of classes (rows of ``weight``).
        s: Scale multiplied onto the cosine logits.
        m: Angular margin hyper-parameter (unused by ``forward``).
    """

    def __init__(self, in_features, out_features, s=35.0, m=0.50):
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor; the
        # storage is uninitialised until xavier_uniform_ fills it below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m

    def forward(self, input):
        """Return scaled cosine logits.

        Args:
            input: Feature batch of shape ``(batch, in_features)``; rows are
                L2-normalised along dim 1 before the dot product.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        return cosine * self.s
class TongueArcResNet(nn.Module):
    """ResNet-18 trunk + CBAM attention + global pooling + cosine-logit head.

    ``self.features`` reuses every child of a randomly initialised
    (``weights=None``) torchvision ResNet-18 except the last two, i.e. without
    its global pool and ``fc`` layer. The same module objects also stay
    registered under ``self.backbone``, so the state_dict holds both
    ``backbone.*`` and ``features.*`` entries (``backbone.fc`` is unused in
    ``forward``). Saved checkpoints presumably contain both key sets, so keep
    both registrations.

    Args:
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = models.resnet18(weights=None)
        trunk_layers = list(self.backbone.children())[:-2]
        self.features = nn.Sequential(*trunk_layers)
        # 512 = channel count of the ResNet-18 trunk's final feature map.
        self.attention = CBAM(512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.arcface = ArcMarginProduct(512, num_classes, s=35)

    def forward(self, x):
        """Map an image batch (assumed ``(N, 3, H, W)``) to ``(N, num_classes)`` logits."""
        attended = self.attention(self.features(x))
        embedding = self.avgpool(attended).flatten(start_dim=1)
        return self.arcface(embedding)
class TongueModelWrapper:
    """CPU inference wrapper around ``TongueArcResNet``.

    Loads trained weights, converts an image array into the three-channel
    (grayscale, CLAHE, |Laplacian|) representation the network consumes, and
    returns softmax probabilities keyed by display label.
    """

    # Display labels in the network's output-index order. These strings are the
    # keys of the dict returned by ``predict``; do not edit them casually.
    LABELS = ("NHC(健康人)", "DES (一般乾眼)", "SJS (乾燥症)")

    def __init__(self, model_path, num_classes=3):
        """Build the network on CPU and load weights from ``model_path``.

        Args:
            model_path: Anything ``torch.load`` accepts. The loaded object is
                either a dict holding a ``'model_state'`` state_dict (and
                optionally an ``'auc'`` metric) or a bare state_dict.
            num_classes: Number of outputs of the classification head.

        Raises:
            RuntimeError: From ``load_state_dict`` when the weights do not match
                the architecture.
        """
        self.device = torch.device("cpu")
        self.model = TongueArcResNet(num_classes=num_classes)
        self._allow_numpy_scalar_unpickling()
        # SECURITY: weights_only=False performs full pickle deserialisation, which
        # can execute arbitrary code. Only load checkpoint files from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state'])
            print(f"成功載入權重!模型訓練指標:AUC={checkpoint.get('auc', 0):.4f}")
        else:
            self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()
        # Standard ImageNet channel mean/std normalisation on the 3-channel image.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _allow_numpy_scalar_unpickling():
        """Best-effort: add NumPy's scalar reconstructor to torch's safe-globals list.

        ``torch.serialization.add_safe_globals`` only exists in newer PyTorch
        releases, and the reconstructor lives under ``numpy._core`` on NumPy 2.x
        but ``numpy.core`` on NumPy 1.x. Missing pieces are skipped instead of
        raising AttributeError.
        """
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is None:
            return
        # Try the NumPy 2.x location first so NumPy 1.x's deprecated alias is
        # only touched when needed.
        core = getattr(np, "_core", None) or getattr(np, "core", None)
        scalar = getattr(getattr(core, "multiarray", None), "scalar", None)
        if scalar is not None:
            add_safe_globals([scalar])

    def preprocess(self, img_array):
        """Convert an image array into the model's 3-channel PIL input.

        The image is converted to grayscale and resized to 512x512. The three
        output channels are the grayscale image, its CLAHE-equalised version,
        and the min-max scaled absolute Laplacian.

        Args:
            img_array: NumPy image, either 2-D grayscale, ``(H, W, 1)``, or
                multi-channel ``(H, W, C)`` in RGB order. Presumably uint8
                (CLAHE needs 8/16-bit input); TODO confirm for all callers.

        Returns:
            ``PIL.Image.Image`` of size 512x512 with three channels.
        """
        if img_array.ndim == 2:
            img_gray = img_array
        elif img_array.shape[2] == 1:
            img_gray = img_array[:, :, 0]
        else:
            img_gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        img_gray = cv2.resize(img_gray, (512, 512))
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        ch_clahe = clahe.apply(img_gray)
        # Edge-magnitude channel: |Laplacian| in float64, rescaled to 0..255 uint8.
        ch_lap = np.absolute(cv2.Laplacian(img_gray, cv2.CV_64F, ksize=3))
        ch_lap = cv2.normalize(ch_lap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        combined = np.stack([img_gray, ch_clahe, ch_lap], axis=-1)
        return Image.fromarray(combined)

    def predict(self, img_array):
        """Classify one image and return per-class probabilities.

        Args:
            img_array: Image array accepted by ``preprocess``, or ``None``.

        Returns:
            ``None`` when ``img_array`` is ``None``; otherwise a dict mapping
            each label in ``LABELS`` (in order) to its softmax probability as a
            Python float.
        """
        if img_array is None:
            return None
        processed_img = self.preprocess(img_array)
        input_tensor = self.transform(processed_img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            # .cpu() keeps this working if self.device is ever switched off CPU.
            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        return {label: float(probs[i]) for i, label in enumerate(self.LABELS)}