# File size: 7,183 Bytes (revision 3f574b1)
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
SegformerImageProcessor,
SegformerForSemanticSegmentation,
)
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
# --- Local model path configuration (make sure the files have been downloaded to these paths) ---
CLS_MODEL_PATH = "models/vit_base"
SEG_MODEL_PATH = "models/segformer_b1"
DET_MODEL_PATH = "yolov8n.pt"  # YOLOv8 manages the download itself, so a plain model ID is enough

# --- 1. Image classification model (ViT) ---
# Remote Hub ID used as a fallback when the local copy cannot be loaded
# (only useful during development).
CLS_MODEL_ID = "google/vit-base-patch16-224"


def _load_cls_model(source):
    """Load the ViT processor and model from ``source`` (local path or Hub ID).

    Returns:
        ``(processor, model, labels)`` where the model is in eval mode and
        ``labels[i]`` is the class name of logit ``i`` (ordered by class id).
    """
    processor = AutoImageProcessor.from_pretrained(source, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(source)
    model.eval()
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    return processor, model, labels


# Defaults used when both the local and the remote load fail; cls_predict
# checks cls_model before touching the processor or labels.
cls_processor = None
cls_model = None
cls_labels = []
try:
    cls_processor, cls_model, cls_labels = _load_cls_model(CLS_MODEL_PATH)
    print("✅ Classification model (ViT) loaded from local path.")
except Exception as e:
    try:
        cls_processor, cls_model, cls_labels = _load_cls_model(CLS_MODEL_ID)
        print("⚠️ Classification model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report both failures: previously only the local-path error was shown,
        # hiding why the remote fallback failed.
        print(f"❌ Error loading classification model: {e} (fallback also failed: {e_fb})")
def cls_predict(input_img: Image.Image, top_k: int = 5) -> dict:
    """Classify an image and return the most likely labels with their probabilities.

    Args:
        input_img: Image to classify. It is converted to RGB first so that
            RGBA / palette / grayscale uploads match the processor's
            3-channel normalisation.
        top_k: Maximum number of labels to return; clamped to the number of
            classes, since ``torch.topk`` raises when k exceeds the dimension.

    Returns:
        Mapping ``label -> probability`` (rounded to 4 decimals), highest
        probability first, or ``{"Error": ...}`` if the model is not loaded.
    """
    if cls_model is None:
        return {"Error": "Classification model not loaded."}
    inputs = cls_processor(images=input_img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = cls_model(**inputs).logits
    # Single image in the batch -> 1-D probability vector over classes.
    probabilities = F.softmax(logits, dim=-1)[0]
    k = max(0, min(top_k, probabilities.shape[-1]))
    top_prob, top_idx = torch.topk(probabilities, k)
    return {
        cls_labels[idx]: round(prob, 4)
        for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
    }
# --- 2. 语义分割模型 (SegFormer) ---
try:
# 从本地路径加载
seg_processor = SegformerImageProcessor.from_pretrained(SEG_MODEL_PATH)
seg_model = SegformerForSemanticSegmentation.from_pretrained(
SEG_MODEL_PATH,
use_safetensors=False, # 确保使用 pytorch_model.bin
trust_remote_code=True
)
seg_model.eval()
# 获取 SegFormer (ADE20K) 的类别标签
seg_labels = [label for label in seg_model.config.id2label.values()]
SEG_CLASSES = len(seg_labels)
# 创建固定的颜色映射 (Matplotlib HSV Colormap)
cmap = plt.cm.get_cmap('hsv', SEG_CLASSES)
COLOR_MAP_float = cmap(np.arange(SEG_CLASSES))[:, :3]
COLOR_MAP_int = (COLOR_MAP_float * 255).astype(np.uint8)
COLOR_MAP_DICT = {}
for i, label in enumerate(seg_labels):
rgb_color = COLOR_MAP_int[i].tolist()
hex_color = '#%02x%02x%02x' % tuple(rgb_color)
COLOR_MAP_DICT[label] = hex_color
COLOR_MAP = COLOR_MAP_int
print("✅ Segmentation model (SegFormer) loaded from local path.")
except Exception as e:
# 远程回退加载
SEG_MODEL_ID = "nvidia/segformer-b1-finetuned-ade-512-512"
try:
seg_processor = SegformerImageProcessor.from_pretrained(SEG_MODEL_ID)
seg_model = SegformerForSemanticSegmentation.from_pretrained(SEG_MODEL_ID, use_safetensors=False,
trust_remote_code=True)
seg_model.eval()
seg_labels = [label for label in seg_model.config.id2label.values()]
SEG_CLASSES = len(seg_labels)
cmap = plt.cm.get_cmap('hsv', SEG_CLASSES)
COLOR_MAP_float = cmap(np.arange(SEG_CLASSES))[:, :3]
COLOR_MAP_int = (COLOR_MAP_float * 255).astype(np.uint8)
COLOR_MAP_DICT = {}
for i, label in enumerate(seg_labels):
rgb_color = COLOR_MAP_int[i].tolist()
hex_color = '#%02x%02x%02x' % tuple(rgb_color)
COLOR_MAP_DICT[label] = hex_color
COLOR_MAP = COLOR_MAP_int
print("⚠️ Segmentation model loaded from remote Hub (Fallback).")
except Exception as e_fb:
print(f"❌ Error loading segmentation model: {e}")
seg_model = None
seg_labels = []
COLOR_MAP_DICT = {"Error": "#000000"}
SEG_CLASSES = 0
def seg_predict(input_img: Image.Image) -> Image.Image:
    """Segment an image and overlay the colour-coded class mask on it.

    Args:
        input_img: Image to segment (any PIL mode; converted to RGB first so it
            matches the processor's 3-channel normalisation).

    Returns:
        An RGB image that is a 50/50 blend of the input and the per-pixel class
        colours, or the input unchanged if the segmentation model is not loaded.
    """
    if seg_model is None:
        return input_img
    rgb_img = input_img.convert("RGB")
    width, height = rgb_img.size
    inputs = seg_processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        logits = seg_model(**inputs).logits.cpu()
    # Upsample the logits to the original (height, width) before taking the
    # per-pixel argmax, so the mask lines up with the input image.
    class_ids = F.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    ).argmax(dim=1)[0].numpy()
    # Class ids stay as integers: the previous uint8 cast wrapped ids >= 256.
    if SEG_CLASSES > 0:
        colored_mask = COLOR_MAP[class_ids % SEG_CLASSES]
    else:
        colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    mask_img = Image.fromarray(colored_mask).convert("RGBA")
    # Blend at 50% opacity; both images are RGBA and the same size.
    return Image.blend(rgb_img.convert("RGBA"), mask_img, alpha=0.5).convert("RGB")
# --- 3. Object detection model (YOLOv8) ---
try:
    # Pass the bare model ID so YOLOv8 handles downloading and caching itself.
    det_model = YOLO(DET_MODEL_PATH)
    # YOLOv8 (COCO) class names as {class_id: name}; re-keyed by the string
    # form of the id for the UI.
    det_labels_dict = det_model.names
    det_labels = {}
    for class_id, class_name in det_labels_dict.items():
        det_labels[str(class_id)] = class_name
    print("✅ Detection model (YOLOv8n) loaded.")
except Exception as e:
    # Any load failure leaves detection disabled; det_predict checks det_model.
    print(f"❌ Error loading detection model: {e}")
    det_model = None
    det_labels = {}
def det_predict(input_img: Image.Image, conf: float, iou: float = 0.5) -> Image.Image:
    """Run YOLOv8 object detection and return the image with detections drawn.

    Args:
        input_img: Image to run detection on.
        conf: Minimum confidence for a detection to be kept.
        iou: IoU threshold passed to YOLO's non-maximum suppression (defaults
            to 0.5, the previously hard-coded value).

    Returns:
        An RGB PIL image with predicted boxes and labels drawn, or the input
        unchanged if the detection model is not loaded.
    """
    if det_model is None:
        return input_img
    results = det_model.predict(
        source=input_img,
        conf=conf,
        iou=iou,
        verbose=False,
    )
    # Results.plot() returns a BGR numpy array (OpenCV channel order). Reverse
    # the channel axis so PIL receives RGB; otherwise red and blue are swapped.
    plotted_bgr = results[0].plot()
    return Image.fromarray(np.ascontiguousarray(plotted_bgr[..., ::-1]))
# --- Export label collections for the Gradio UI ---
# Every loader above defines these names, as empty containers when a model
# failed to load. Fall back to the placeholders when a value is empty, not only
# when the name is missing; otherwise the "not available" messages are never
# shown.
ALL_CLS_LABELS = globals().get("cls_labels") or ["Classification labels not available."]
ALL_DET_LABELS = globals().get("det_labels") or {"Error": "Detection labels not available."}
ALL_SEG_COLOR_MAP = globals().get("COLOR_MAP_DICT") or {"Error": "#000000"}
ALL_SEG_LABELS = globals().get("seg_labels") or ["Segmentation labels not available."]