File size: 7,183 Bytes
3f574b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
)
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt

# --- Local model path configuration (make sure the files have been downloaded to these paths) ---
CLS_MODEL_PATH = "models/vit_base"
SEG_MODEL_PATH = "models/segformer_b1"
DET_MODEL_PATH = "yolov8n.pt"  # YOLOv8 manages the download itself, so the plain model ID is used here

# --- 1. Image classification model (ViT) ---
# Hub identifier used only when the local checkpoint cannot be loaded
# (mainly useful during development).
CLS_MODEL_ID = "google/vit-base-patch16-224"


def _load_classifier(source):
    """Load the classification processor, model and label list from ``source``.

    Args:
        source: Local directory or Hub model ID, passed unchanged to
            ``from_pretrained``.

    Returns:
        A ``(processor, model, labels)`` tuple. The model has been switched to
        eval mode and ``labels[i]`` is ``model.config.id2label[i]``.

    Raises:
        Whatever ``from_pretrained`` raises, or ``KeyError`` if ``id2label``
        lacks an index below ``num_labels``.
    """
    processor = AutoImageProcessor.from_pretrained(source, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(source)
    model.eval()
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    return processor, model, labels


# Bind every name up front so that later code can rely on them existing even
# when both load attempts fail.
cls_processor = None
cls_model = None
cls_labels = []

try:
    # Preferred source: the local checkpoint directory.
    cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_PATH)
    print("✅ Classification model (ViT) loaded from local path.")
except Exception as local_err:
    # Local load failed: fall back to the remote Hub ID.
    try:
        cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_ID)
        print("⚠️ Classification model loaded from remote Hub (Fallback).")
    except Exception as remote_err:
        # Report both failures; the fallback error is often the more useful one.
        print(f"❌ Error loading classification model: {local_err}")
        print(f"❌ Remote fallback for classification model also failed: {remote_err}")
        cls_processor = None
        cls_model = None
        cls_labels = []


def cls_predict(input_img: Image.Image) -> dict:
    """Classify an image and return its most probable labels.

    Args:
        input_img: Image to classify; forwarded to ``cls_processor``.

    Returns:
        A dict mapping label -> softmax probability rounded to 4 decimals,
        holding at most five entries in descending probability order. When the
        classification model is not loaded, a single-entry ``{"Error": ...}``
        dict is returned instead.
    """
    if cls_model is None:
        return {"Error": "Classification model not loaded."}

    inputs = cls_processor(images=input_img, return_tensors="pt")

    with torch.no_grad():
        logits = cls_model(**inputs).logits
        probabilities = F.softmax(logits, dim=1)

    # torch.topk raises when k exceeds the class dimension, so clamp it for
    # models that expose fewer than five classes.
    k = min(5, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, k)

    # Index 0 selects the single image in the batch.
    return {
        cls_labels[class_idx]: round(prob, 4)
        for class_idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    }


# --- 2. Semantic segmentation model (SegFormer) ---
# Hub identifier used only when the local checkpoint cannot be loaded.
SEG_MODEL_ID = "nvidia/segformer-b1-finetuned-ade-512-512"


def _build_seg_palette(labels):
    """Build a fixed per-class colour table from Matplotlib's ``hsv`` colormap.

    Args:
        labels: Class names; position ``i`` is the name of class ``i``.

    Returns:
        A ``(palette, legend)`` tuple: ``palette`` is a ``(len(labels), 3)``
        uint8 RGB array and ``legend`` maps each label to its ``#rrggbb``
        string. Labels that occur more than once keep the last colour.
    """
    count = len(labels)
    if count == 0:
        return np.zeros((0, 3), dtype=np.uint8), {}

    # plt.get_cmap is used instead of plt.cm.get_cmap, which recent Matplotlib
    # releases no longer provide.
    cmap = plt.get_cmap('hsv', count)
    # Drop the alpha channel and scale the 0..1 floats to 0..255.
    palette = (cmap(np.arange(count))[:, :3] * 255).astype(np.uint8)
    legend = {
        label: '#%02x%02x%02x' % tuple(rgb)
        for label, rgb in zip(labels, palette.tolist())
    }
    return palette, legend


def _load_segmenter(source):
    """Load the SegFormer processor/model from ``source`` and derive its palette.

    Args:
        source: Local directory or Hub model ID, passed unchanged to
            ``from_pretrained``.

    Returns:
        A ``(processor, model, labels, palette, legend)`` tuple; the model has
        been switched to eval mode, and ``palette``/``legend`` come from
        ``_build_seg_palette(labels)``.

    Raises:
        Whatever ``from_pretrained`` or the palette construction raises.
    """
    processor = SegformerImageProcessor.from_pretrained(source)
    model = SegformerForSemanticSegmentation.from_pretrained(
        source,
        use_safetensors=False,  # make sure pytorch_model.bin is used
        # NOTE(review): trust_remote_code allows code shipped with a checkpoint
        # to be executed; confirm it is really needed for this model.
        trust_remote_code=True,
    )
    model.eval()

    # Order labels by class id so that position i lines up with palette row i.
    id2label = model.config.id2label
    labels = [id2label[class_id] for class_id in sorted(id2label)]
    palette, legend = _build_seg_palette(labels)
    return processor, model, labels, palette, legend


# Bind every name up front so that later code can rely on them existing even
# when both load attempts fail.
seg_processor = None
seg_model = None
seg_labels = []
SEG_CLASSES = 0
COLOR_MAP = np.zeros((0, 3), dtype=np.uint8)
COLOR_MAP_DICT = {"Error": "#000000"}

try:
    # Preferred source: the local checkpoint directory.
    seg_processor, seg_model, seg_labels, COLOR_MAP, COLOR_MAP_DICT = _load_segmenter(SEG_MODEL_PATH)
    SEG_CLASSES = len(seg_labels)
    print("✅ Segmentation model (SegFormer) loaded from local path.")
except Exception as local_err:
    # Local load failed: fall back to the remote Hub ID.
    try:
        seg_processor, seg_model, seg_labels, COLOR_MAP, COLOR_MAP_DICT = _load_segmenter(SEG_MODEL_ID)
        SEG_CLASSES = len(seg_labels)
        print("⚠️ Segmentation model loaded from remote Hub (Fallback).")
    except Exception as remote_err:
        # Report both failures; the fallback error is often the more useful one.
        print(f"❌ Error loading segmentation model: {local_err}")
        print(f"❌ Remote fallback for segmentation model also failed: {remote_err}")
        seg_processor = None
        seg_model = None
        seg_labels = []
        SEG_CLASSES = 0
        COLOR_MAP = np.zeros((0, 3), dtype=np.uint8)
        COLOR_MAP_DICT = {"Error": "#000000"}


def seg_predict(input_img: Image.Image) -> Image.Image:
    """Run semantic segmentation and overlay the coloured class mask on the image.

    Args:
        input_img: Image to segment; forwarded to ``seg_processor``.

    Returns:
        An RGB image in which the original and the per-class colour mask are
        blended 50/50. When the segmentation model is not loaded, ``input_img``
        is returned unchanged.
    """
    if seg_model is None:
        return input_img

    # PIL reports (width, height); interpolate expects (height, width).
    width, height = input_img.size

    inputs = seg_processor(images=input_img, return_tensors="pt")

    with torch.no_grad():
        logits = seg_model(**inputs).logits.cpu()

    # Upsample the logits to the input resolution, then pick the best class
    # per pixel for the single image in the batch.
    upsampled = F.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )
    # Keep the integer dtype produced by argmax: casting to uint8 would wrap
    # class ids >= 256 and select the wrong palette rows.
    class_ids = upsampled.argmax(dim=1)[0].numpy()

    # Map class ids to colours.
    if SEG_CLASSES > 0:
        colored_mask = COLOR_MAP[class_ids % SEG_CLASSES]
    else:
        colored_mask = np.zeros((*class_ids.shape, 3), dtype=np.uint8)

    mask_img = Image.fromarray(colored_mask).convert("RGBA")

    # Overlay on the original image with 50% opacity; Image.blend needs both
    # images in the same mode.
    base_img = input_img.convert("RGBA")
    return Image.blend(base_img, mask_img, alpha=0.5).convert("RGB")


# --- 3. Object detection model (YOLOv8) ---
try:
    # Keep passing the plain ID so that YOLOv8 handles download and caching itself.
    det_model = YOLO(DET_MODEL_PATH)

    # Class labels of the detector, re-keyed by the string form of the class id.
    det_labels_dict = det_model.names
    det_labels = {}
    for class_id, class_name in det_labels_dict.items():
        det_labels[str(class_id)] = class_name

    print("✅ Detection model (YOLOv8n) loaded.")
except Exception as load_error:
    print(f"❌ Error loading detection model: {load_error}")
    det_model, det_labels = None, {}


def det_predict(input_img: Image.Image, conf: float) -> Image.Image:
    """Run object detection and return the image with detections drawn on it.

    Args:
        input_img: Image to run detection on; passed as ``source`` to
            ``det_model.predict``.
        conf: Confidence threshold forwarded to ``det_model.predict``.

    Returns:
        An image with the detections of the first (only) result plotted. When
        the detection model is not loaded, ``input_img`` is returned unchanged.
    """
    if det_model is None:
        return input_img

    results = det_model.predict(
        source=input_img,
        conf=conf,
        iou=0.5,
        verbose=False,
    )

    # Ultralytics documents Results.plot() as returning a BGR-ordered array;
    # reverse the channel axis so PIL receives RGB and colours are not swapped.
    annotated_bgr = results[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])


# --- Export label collections for the Gradio interface ---
# The loading code binds these names even when loading fails (to empty
# containers), so a bare existence check never selects the placeholder.
# Look each name up defensively and use the placeholder whenever the value is
# missing or empty.
ALL_CLS_LABELS = globals().get('cls_labels') or ["Classification labels not available."]
ALL_DET_LABELS = globals().get('det_labels') or {"Error": "Detection labels not available."}
ALL_SEG_COLOR_MAP = globals().get('COLOR_MAP_DICT') or {"Error": "#000000"}
ALL_SEG_LABELS = globals().get('seg_labels') or ["Segmentation labels not available."]