# ai4vision / model.py
# Author: 2ephyrh
# Commit: "Upload 18 files" (3f574b1, verified)
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import (
AutoImageProcessor,
AutoModelForImageClassification,
SegformerImageProcessor,
SegformerForSemanticSegmentation,
)
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
# --- Local model paths (make sure the model files have been downloaded here) ---
# Resolve relative to this file rather than the process working directory, so
# the local copies are found even when the app is launched from elsewhere.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
CLS_MODEL_PATH = os.path.join(_MODULE_DIR, "models", "vit_base")
SEG_MODEL_PATH = os.path.join(_MODULE_DIR, "models", "segformer_b1")
# Bare weight name on purpose: Ultralytics resolves/downloads/caches it itself.
DET_MODEL_PATH = "yolov8n.pt"
# --- 1. Image classification model (ViT) ---
CLS_MODEL_ID = "google/vit-base-patch16-224"  # Hub fallback if the local copy fails


def _load_classifier(source):
    """Load the ViT image processor and classifier from *source*.

    Args:
        source: a local model directory or a Hugging Face Hub model ID.

    Returns:
        (processor, model, labels) where labels[i] is the name of class id i.
    """
    processor = AutoImageProcessor.from_pretrained(source, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(source)
    model.eval()
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    return processor, model, labels


try:
    # Prefer the local copy.
    cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_PATH)
    print("✅ Classification model (ViT) loaded from local path.")
except Exception as e:
    # Local load failed: fall back to the remote Hub ID (mainly useful in development).
    try:
        cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_ID)
        print("⚠️ Classification model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report both failures so a failed Hub fallback (e.g. no network) is
        # not masked by the local-path error.
        print(f"❌ Error loading classification model: local: {e}; remote: {e_fb}")
        cls_model = None
        cls_labels = []
def cls_predict(input_img: Image.Image, top_k: int = 5) -> dict:
    """Classify an image with the ViT model and return the most likely labels.

    Args:
        input_img: PIL image to classify. It is converted to RGB first so RGBA,
            grayscale and palette images are accepted.
        top_k: number of highest-probability classes to return. It is clamped
            to the number of classes the model has.

    Returns:
        A dict mapping label -> probability (rounded to 4 decimals), inserted in
        descending probability order. Returns {"Error": ...} when the model
        failed to load.
    """
    if cls_model is None:
        return {"Error": "Classification model not loaded."}
    inputs = cls_processor(images=input_img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = cls_model(**inputs)
    # Single image -> take batch row 0 of the softmaxed logits.
    probabilities = F.softmax(outputs.logits, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.shape[-1])
    top_prob, top_indices = torch.topk(probabilities, k)
    results = {}
    for prob, idx in zip(top_prob.tolist(), top_indices.tolist()):
        # Some label sets repeat a name under different ids (ImageNet-1k has two
        # "crane" classes). Keep the first, i.e. highest, probability instead of
        # letting a lower one overwrite it.
        results.setdefault(cls_labels[idx], round(prob, 4))
    return results
# --- 2. Semantic segmentation model (SegFormer) ---
SEG_MODEL_ID = "nvidia/segformer-b1-finetuned-ade-512-512"  # Hub fallback


def _build_seg_color_maps(labels):
    """Assign each segmentation class a fixed color from Matplotlib's 'hsv' colormap.

    Returns:
        (palette, color_dict): palette is a (len(labels), 3) uint8 RGB array
        indexed by class id; color_dict maps each label to its '#rrggbb' color.
    """
    n_classes = len(labels)
    # plt.get_cmap rather than plt.cm.get_cmap: matplotlib.cm.get_cmap was
    # deprecated in Matplotlib 3.7 and removed in 3.9. On current Matplotlib it
    # raised here and made every segmentation load fail.
    cmap = plt.get_cmap("hsv", n_classes)
    palette = (cmap(np.arange(n_classes))[:, :3] * 255).astype(np.uint8)
    color_dict = {
        label: "#%02x%02x%02x" % tuple(rgb.tolist())
        for label, rgb in zip(labels, palette)
    }
    return palette, color_dict


def _load_segmenter(source):
    """Load SegFormer from *source* (local directory or Hub ID) with its labels and colors.

    Returns:
        (processor, model, labels, palette, color_dict).
    """
    processor = SegformerImageProcessor.from_pretrained(source)
    model = SegformerForSemanticSegmentation.from_pretrained(
        source,
        # Forces pytorch_model.bin, which torch.load unpickles.
        # NOTE(review): only point this at trusted checkpoints.
        use_safetensors=False,
        trust_remote_code=True,
    )
    model.eval()
    # Build the list by class id (not dict iteration order) so labels[i] and
    # palette[i] line up with the argmax ids produced at inference time.
    labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    palette, color_dict = _build_seg_color_maps(labels)
    return processor, model, labels, palette, color_dict


try:
    # Prefer the local copy.
    seg_processor, seg_model, seg_labels, COLOR_MAP, COLOR_MAP_DICT = _load_segmenter(
        SEG_MODEL_PATH
    )
    print("✅ Segmentation model (SegFormer) loaded from local path.")
except Exception as e:
    # Local load failed: fall back to the remote Hub ID.
    try:
        seg_processor, seg_model, seg_labels, COLOR_MAP, COLOR_MAP_DICT = _load_segmenter(
            SEG_MODEL_ID
        )
        print("⚠️ Segmentation model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report both failures instead of only the local one.
        print(f"❌ Error loading segmentation model: local: {e}; remote: {e_fb}")
        seg_model = None
        seg_labels = []
        COLOR_MAP_DICT = {"Error": "#000000"}
# Number of classes (0 when the model failed to load).
SEG_CLASSES = len(seg_labels)
def seg_predict(input_img: Image.Image) -> Image.Image:
    """Segment an image with SegFormer and overlay the class colors on it.

    Args:
        input_img: PIL image to segment. It is converted to RGB before
            preprocessing so RGBA/grayscale/palette images are accepted.

    Returns:
        An RGB image: the input blended 50/50 with a per-pixel class-color mask.
        Returns input_img unchanged if the segmentation model failed to load.
    """
    if seg_model is None:
        return input_img
    rgb_img = input_img.convert("RGB")
    width, height = rgb_img.size
    inputs = seg_processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        outputs = seg_model(**inputs)
    # The logits are at the model's working resolution; upsample them to the
    # original (height, width) before taking the per-pixel argmax.
    logits = outputs.logits.cpu()
    class_map = F.interpolate(
        logits,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    ).argmax(dim=1)[0]
    # Keep the integer class ids as-is. A uint8 cast would wrap ids >= 256
    # before the modulo below.
    class_ids = class_map.numpy()
    if SEG_CLASSES > 0:
        colored_mask = COLOR_MAP[class_ids % SEG_CLASSES]
    else:
        colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    mask_img = Image.fromarray(colored_mask).convert("RGBA")
    # 50% blend of the original image and the color mask.
    return Image.blend(rgb_img.convert("RGBA"), mask_img, alpha=0.5).convert("RGB")
# --- 3. Object detection model (YOLOv8) ---
try:
    # A bare weight name is passed on purpose: Ultralytics resolves,
    # downloads and caches the checkpoint itself.
    det_model = YOLO(DET_MODEL_PATH)
    # Class-id -> class-name mapping exposed by the detector
    # (presumably the COCO classes for yolov8n.pt).
    det_labels_dict = det_model.names
    # Same mapping, re-keyed with string ids.
    det_labels = dict(zip(map(str, det_labels_dict), det_labels_dict.values()))
    print("✅ Detection model (YOLOv8n) loaded.")
except Exception as e:
    print(f"❌ Error loading detection model: {e}")
    det_model, det_labels = None, {}
def det_predict(input_img: Image.Image, conf: float, iou: float = 0.5) -> Image.Image:
    """Run YOLOv8 object detection and return the image with detections drawn.

    Args:
        input_img: PIL image to run detection on.
        conf: minimum confidence for a detection to be kept.
        iou: IoU threshold passed to the predictor (defaults to the
            previously hard-coded 0.5).

    Returns:
        The annotated image as an RGB PIL image. Returns input_img unchanged if
        the detector failed to load.
    """
    if det_model is None:
        return input_img
    results = det_model.predict(
        source=input_img,
        conf=conf,
        iou=iou,
        verbose=False,
    )
    # Results.plot() returns a BGR ndarray (OpenCV channel order). Reverse the
    # channel axis so PIL receives RGB; otherwise red and blue are swapped.
    plotted_bgr = results[0].plot()
    return Image.fromarray(np.ascontiguousarray(plotted_bgr[..., ::-1]))
# --- Label lists exported for the Gradio interface ---
# The loaders bind these names even on failure (to empty values), so an
# existence check alone never yields the placeholder. Fall back when the name
# is missing OR empty.
ALL_CLS_LABELS = globals().get("cls_labels") or ["Classification labels not available."]
ALL_DET_LABELS = globals().get("det_labels") or {"Error": "Detection labels not available."}
ALL_SEG_COLOR_MAP = globals().get("COLOR_MAP_DICT") or {"Error": "#000000"}
ALL_SEG_LABELS = globals().get("seg_labels") or ["Segmentation labels not available."]