PinHsuan commited on
Commit
94c7c73
·
verified ·
1 Parent(s): 9167ef6

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +20 -0
  2. best_model.pt +3 -0
  3. requirements.txt +6 -0
  4. tongue_model.py +89 -0
app.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from tongue_model import TongueModelWrapper
3
+
4
+ wrapper = TongueModelWrapper(model_path="best_model.pt")
5
+
6
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
7
+ gr.Markdown("# 🏥 中醫舌象自動診斷系統")
8
+
9
+ with gr.Row():
10
+ with gr.Column():
11
+ input_img = gr.Image(label="上傳舌象照片")
12
+ btn = gr.Button("🚀 開始分析", variant="primary")
13
+
14
+ with gr.Column():
15
+ output_label = gr.Label(label="預測結果機率")
16
+
17
+ btn.click(fn=wrapper.predict, inputs=input_img, outputs=output_label)
18
+
19
+ if __name__ == "__main__":
20
+ demo.launch()
best_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3326eba752731c60c77defabf73e8e21bb6a23ced04285462e8c67d5ec4d71df
3
+ size 46996455
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ opencv-python-headless
5
+ numpy
6
+ Pillow
tongue_model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from torchvision import models, transforms
8
+
9
+ # --- 模型架構定義 ---
10
+ class CBAM(nn.Module):
11
+ def __init__(self, channels, reduction=16):
12
+ super(CBAM, self).__init__()
13
+ self.ca = nn.Sequential(
14
+ nn.AdaptiveAvgPool2d(1),
15
+ nn.Conv2d(channels, channels // reduction, 1, bias=False),
16
+ nn.ReLU(),
17
+ nn.Conv2d(channels // reduction, channels, 1, bias=False)
18
+ )
19
+ self.sa = nn.Sequential(
20
+ nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
21
+ nn.Sigmoid()
22
+ )
23
+ self.ca_sigmoid = nn.Sigmoid()
24
+ def forward(self, x):
25
+ x = x * self.ca_sigmoid(self.ca(x))
26
+ avg_out = torch.mean(x, dim=1, keepdim=True); max_out, _ = torch.max(x, dim=1, keepdim=True)
27
+ x = x * self.sa(torch.cat([avg_out, max_out], dim=1))
28
+ return x
29
+
30
+ class ArcMarginProduct(nn.Module):
31
+ def __init__(self, in_features, out_features, s=35.0, m=0.50):
32
+ super(ArcMarginProduct, self).__init__()
33
+ self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
34
+ nn.init.xavier_uniform_(self.weight)
35
+ self.s = s
36
+ def forward(self, input):
37
+ cosine = F.linear(F.normalize(input), F.normalize(self.weight))
38
+ return cosine * self.s
39
+
40
+ class TongueArcResNet(nn.Module):
41
+ def __init__(self, num_classes=3):
42
+ super().__init__()
43
+ self.backbone = models.resnet18(weights=None)
44
+ self.features = nn.Sequential(*list(self.backbone.children())[:-2])
45
+ self.attention = CBAM(512)
46
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
47
+ self.arcface = ArcMarginProduct(512, num_classes, s=35)
48
+ def forward(self, x):
49
+ x = self.features(x)
50
+ x = self.attention(x)
51
+ features = self.avgpool(x).flatten(1)
52
+ return self.arcface(features)
53
+
54
+ # --- 2. 定義預處理與推論類別 ---
55
+ class TongueModelWrapper:
56
+ def __init__(self, model_path, num_classes=2):
57
+ self.device = torch.device("cpu")
58
+ self.model = TongueArcResNet(num_classes=num_classes)
59
+ self.model.load_state_dict(torch.load(model_path, map_location=self.device))
60
+ self.model.eval()
61
+ self.transform = transforms.Compose([
62
+ transforms.ToTensor(),
63
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
64
+ ])
65
+
66
+ def preprocess(self, img_array):
67
+ img_gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
68
+ img_gray = cv2.resize(img_gray, (512, 512))
69
+
70
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
71
+ ch_clahe = clahe.apply(img_gray)
72
+
73
+ ch_lap = np.absolute(cv2.Laplacian(img_gray, cv2.CV_64F, ksize=3))
74
+ ch_lap = cv2.normalize(ch_lap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
75
+
76
+ combined = np.stack([img_gray, ch_clahe, ch_lap], axis=-1)
77
+ return Image.fromarray(combined)
78
+
79
+ def predict(self, img_array):
80
+ if img_array is None: return None
81
+
82
+ processed_img = self.preprocess(img_array)
83
+ input_tensor = self.transform(processed_img).unsqueeze(0).to(self.device)
84
+
85
+ with torch.no_grad():
86
+ outputs = self.model(input_tensor)
87
+ probs = torch.softmax(outputs, dim=1).numpy()[0]
88
+
89
+ return {"NHC(健康人)":float(probs[0]), "DES (一般乾眼)": float(probs[1]), "SJS (乾燥症)": float(probs[2])}