import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
)
from ultralytics import YOLO

# --- Local model path configuration (the files must already be downloaded
# --- to the following paths) ---
CLS_MODEL_PATH = "models/vit_base"
SEG_MODEL_PATH = "models/segformer_b1"
DET_MODEL_PATH = "yolov8n.pt"  # YOLOv8 manages its own download, so a plain ID is used

# Hub IDs used as a fallback when loading from the local paths fails
# (mainly useful during development).
CLS_MODEL_ID = "google/vit-base-patch16-224"
SEG_MODEL_ID = "nvidia/segformer-b1-finetuned-ade-512-512"


def _ordered_labels(config) -> list:
    """Return the label names of *config* ordered by class id.

    Indexing by id (instead of relying on ``id2label.values()`` order) keeps
    position ``i`` of the list aligned with model output channel ``i``.
    """
    return [config.id2label[i] for i in range(config.num_labels)]


def _load_classifier(source: str):
    """Load the image processor, model and label list for classification.

    Args:
        source: Local directory or Hub model ID passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, labels)`` tuple; the model is in eval mode.
    """
    processor = AutoImageProcessor.from_pretrained(source, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(source)
    model.eval()
    return processor, model, _ordered_labels(model.config)


def _load_segmenter(source: str):
    """Load the image processor, model and label list for segmentation.

    Args:
        source: Local directory or Hub model ID passed to ``from_pretrained``.

    Returns:
        A ``(processor, model, labels)`` tuple; the model is in eval mode.
    """
    processor = SegformerImageProcessor.from_pretrained(source)
    # NOTE: ``trust_remote_code`` was dropped on purpose. It only applies to
    # the Auto* classes and is ignored by a concrete model class.
    model = SegformerForSemanticSegmentation.from_pretrained(
        source,
        use_safetensors=False,  # make sure pytorch_model.bin is used
    )
    model.eval()
    return processor, model, _ordered_labels(model.config)


def _build_color_map(labels: list):
    """Create a fixed per-class colour table from Matplotlib's HSV colormap.

    Args:
        labels: Class names, ordered by class id.

    Returns:
        ``(colors, hex_by_label)`` where ``colors`` is a ``uint8`` array with
        one RGB row per class and ``hex_by_label`` maps each label to the
        ``#rrggbb`` string of its colour.
    """
    num_classes = len(labels)
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap``
    # accepts the same (name, lut) arguments.
    palette = plt.get_cmap("hsv", num_classes)
    colors = (palette(np.arange(num_classes))[:, :3] * 255).astype(np.uint8)
    hex_by_label = {
        label: "#%02x%02x%02x" % tuple(rgb)
        for label, rgb in zip(labels, colors.tolist())
    }
    return colors, hex_by_label


# --- 1. Image classification model (ViT) ---
try:
    # Load from the local path.
    cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_PATH)
    print("✅ Classification model (ViT) loaded from local path.")
except Exception as e:
    # The local path failed: fall back to the remote Hub ID.
    try:
        cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_ID)
        print("⚠️ Classification model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report both failures; the fallback error used to be dropped.
        print(
            f"❌ Error loading classification model: {e_fb} "
            f"(local load error: {e})"
        )
        cls_processor = None
        cls_model = None
        cls_labels = []


def cls_predict(input_img: Image.Image) -> dict:
    """Classify an image and return its highest-probability labels.

    Args:
        input_img: Image to classify.

    Returns:
        A dict mapping label name to probability (rounded to 4 decimals) for
        up to five classes, or a single ``"Error"`` entry when the model is
        not loaded or no image was supplied.
    """
    if cls_model is None:
        return {"Error": "Classification model not loaded."}
    if input_img is None:
        return {"Error": "No input image provided."}

    # Normalise grayscale / RGBA / palette inputs to three channels.
    inputs = cls_processor(images=input_img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = cls_model(**inputs)

    probabilities = F.softmax(outputs.logits, dim=1)
    # Cap k so models with fewer than five classes do not make topk raise.
    top_k = min(5, probabilities.shape[1])
    top_prob, top_indices = torch.topk(probabilities, top_k)
    return {
        cls_labels[idx]: round(prob, 4)
        for idx, prob in zip(top_indices[0].tolist(), top_prob[0].tolist())
    }


# --- 2. Semantic segmentation model (SegFormer) ---
try:
    # Load from the local path.
    seg_processor, seg_model, seg_labels = _load_segmenter(SEG_MODEL_PATH)
    SEG_CLASSES = len(seg_labels)
    COLOR_MAP, COLOR_MAP_DICT = _build_color_map(seg_labels)
    print("✅ Segmentation model (SegFormer) loaded from local path.")
except Exception as e:
    # Remote fallback.
    try:
        seg_processor, seg_model, seg_labels = _load_segmenter(SEG_MODEL_ID)
        SEG_CLASSES = len(seg_labels)
        COLOR_MAP, COLOR_MAP_DICT = _build_color_map(seg_labels)
        print("⚠️ Segmentation model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        print(
            f"❌ Error loading segmentation model: {e_fb} "
            f"(local load error: {e})"
        )
        seg_processor = None
        seg_model = None
        seg_labels = []
        SEG_CLASSES = 0
        COLOR_MAP = np.zeros((0, 3), dtype=np.uint8)
        COLOR_MAP_DICT = {"Error": "#000000"}


def seg_predict(input_img: Image.Image) -> Image.Image:
    """Run semantic segmentation and overlay the coloured mask on the image.

    Args:
        input_img: Image to segment.

    Returns:
        An RGB image blending the input with the per-class colour mask at 50%
        opacity. The input is returned unchanged when the model is not loaded
        or no image was supplied.
    """
    if seg_model is None or input_img is None:
        return input_img

    rgb_img = input_img.convert("RGB")
    width, height = rgb_img.size

    inputs = seg_processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        outputs = seg_model(**inputs)

    # Upsample the logits to the input resolution, then take the best class
    # per pixel. F.interpolate expects (height, width); PIL gives (width, height).
    logits = outputs.logits.cpu()
    class_map = (
        F.interpolate(
            logits,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        .argmax(dim=1)[0]
        .numpy()
    )

    # Map class ids to colours. The ids are kept as integers from argmax: a
    # uint8 cast would wrap around for models with more than 256 classes.
    if SEG_CLASSES > 0:
        colored_mask = COLOR_MAP[class_map % SEG_CLASSES]
    else:
        colored_mask = np.zeros((*class_map.shape, 3), dtype=np.uint8)

    # Overlay on the original image (50% opacity).
    mask_img = Image.fromarray(colored_mask).convert("RGBA")
    blended = Image.blend(rgb_img.convert("RGBA"), mask_img, alpha=0.5)
    return blended.convert("RGB")


# --- 3. Object detection model (YOLOv8) ---
try:
    # Keep using the ID so YOLOv8 manages download and caching itself.
    det_model = YOLO(DET_MODEL_PATH)
    # Class labels of the loaded model, keyed by the class id as a string.
    det_labels = {str(k): v for k, v in det_model.names.items()}
    print("✅ Detection model (YOLOv8n) loaded.")
except Exception as e:
    print(f"❌ Error loading detection model: {e}")
    det_model = None
    det_labels = {}


def det_predict(input_img: Image.Image, conf: float) -> Image.Image:
    """Run object detection and return the image with detections drawn.

    Args:
        input_img: Image to run detection on.
        conf: Confidence threshold passed to the detector.

    Returns:
        An RGB image with the detections drawn on it. The input is returned
        unchanged when the model is not loaded or no image was supplied.
    """
    if det_model is None or input_img is None:
        return input_img

    results = det_model.predict(
        source=input_img,
        conf=conf,
        iou=0.5,
        verbose=False,
    )
    # Results.plot() returns a BGR array; reverse the channel axis so PIL
    # (which expects RGB) does not swap red and blue.
    plotted_bgr = results[0].plot()
    return Image.fromarray(np.ascontiguousarray(plotted_bgr[..., ::-1]))


# --- Exported label collections for the Gradio interface ---
# Every name below is bound on all load paths above (success, fallback and
# failure), so they are exported directly.
ALL_CLS_LABELS = cls_labels
ALL_DET_LABELS = det_labels
ALL_SEG_COLOR_MAP = COLOR_MAP_DICT
ALL_SEG_LABELS = seg_labels