|
|
import os
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from PIL import Image
|
|
|
from torchvision import transforms
|
|
|
from transformers import (
|
|
|
AutoImageProcessor,
|
|
|
AutoModelForImageClassification,
|
|
|
SegformerImageProcessor,
|
|
|
SegformerForSemanticSegmentation,
|
|
|
)
|
|
|
import numpy as np
|
|
|
from ultralytics import YOLO
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
# Model locations used by the loaders further down in this module.
# The first two are handed to the Hugging Face ``from_pretrained`` loaders,
# the last one to the ``YOLO`` constructor.
CLS_MODEL_PATH: str = "models/vit_base"
SEG_MODEL_PATH: str = "models/segformer_b1"
DET_MODEL_PATH: str = "yolov8n.pt"
|
|
|
|
|
|
|
|
|
# Hub checkpoint tried when the copy at CLS_MODEL_PATH cannot be loaded.
CLS_MODEL_ID = "google/vit-base-patch16-224"


def _load_classifier(source):
    """Load the classification processor, model and label list from *source*.

    Args:
        source: Value passed to ``from_pretrained`` (a local path or a Hub id).

    Returns:
        A ``(processor, model, labels)`` tuple. The model is switched to eval
        mode and ``labels[i]`` is ``model.config.id2label[i]``.

    Raises:
        Exception: Whatever the underlying ``from_pretrained`` calls raise.
    """
    processor = AutoImageProcessor.from_pretrained(source, use_fast=True)
    model = AutoModelForImageClassification.from_pretrained(source)
    model.eval()
    id2label = model.config.id2label
    labels = [id2label[i] for i in range(model.config.num_labels)]
    return processor, model, labels


# Try the local copy first, then the Hub. Loading must never abort the import
# of this module, so the last resort is a disabled classifier.
try:
    cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_PATH)
    print("✅ Classification model (ViT) loaded from local path.")
except Exception as e:
    try:
        cls_processor, cls_model, cls_labels = _load_classifier(CLS_MODEL_ID)
        print("⚠️ Classification model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report the fallback failure as well as the local one; the local
        # error alone hides why the Hub download did not rescue the load.
        print(
            f"❌ Error loading classification model: {e_fb} "
            f"(local load failed with: {e})"
        )
        cls_processor = None
        cls_model = None
        cls_labels = []
|
|
|
|
|
|
|
|
|
def cls_predict(input_img: Image.Image) -> dict:
    """Classify an image and return its most probable labels.

    Args:
        input_img: PIL image to classify. It is converted to RGB before being
            handed to the image processor.

    Returns:
        A dict mapping label to probability (rounded to 4 decimals) for the
        top predictions, most probable first. At most five entries are
        returned, fewer when the model has fewer than five classes. When no
        classification model is loaded, ``{"Error": ...}`` is returned.
    """
    if cls_model is None:
        return {"Error": "Classification model not loaded."}

    # Grayscale / RGBA / palette images would otherwise reach the processor
    # with a channel count other than three.
    inputs = cls_processor(images=input_img.convert("RGB"), return_tensors="pt")

    with torch.no_grad():
        logits = cls_model(**inputs).logits
    probabilities = F.softmax(logits, dim=-1)[0]

    # topk raises when k exceeds the number of classes, so clamp it.
    top_k = min(5, probabilities.shape[-1])
    top_prob, top_indices = torch.topk(probabilities, top_k)

    return {
        cls_labels[index]: round(prob, 4)
        for index, prob in zip(top_indices.tolist(), top_prob.tolist())
    }
|
|
|
|
|
|
|
|
|
|
|
|
# Hub checkpoint tried when the copy at SEG_MODEL_PATH cannot be loaded.
SEG_MODEL_ID = "nvidia/segformer-b1-finetuned-ade-512-512"


def _init_segmentation(source):
    """Load SegFormer from *source* and derive its labels and colour palette.

    Args:
        source: Value passed to ``from_pretrained`` (a local path or a Hub id).

    Returns:
        A tuple ``(processor, model, labels, palette_float, palette_int,
        hex_by_label)`` where ``labels`` is ordered by class id,
        ``palette_float`` / ``palette_int`` are ``(len(labels), 3)`` RGB
        arrays (floats as returned by the colormap, and uint8 after scaling
        by 255) and ``hex_by_label`` maps each label to ``"#rrggbb"``.

    Raises:
        Exception: Whatever model loading or palette construction raises.
    """
    processor = SegformerImageProcessor.from_pretrained(source)
    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to run, and use_safetensors=False opts out of safetensors
    # weights. Both are kept as in the original; confirm they are needed.
    model = SegformerForSemanticSegmentation.from_pretrained(
        source,
        use_safetensors=False,
        trust_remote_code=True,
    )
    model.eval()

    # Order labels by class id rather than by dict insertion order so that
    # labels[i] and palette[i] always describe class i.
    # Assumes id2label keys are the integer class ids -- TODO confirm.
    id2label = model.config.id2label
    labels = [id2label[class_id] for class_id in sorted(id2label)]

    # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib removed.
    num_classes = len(labels)
    palette_float = plt.get_cmap("hsv", num_classes)(np.arange(num_classes))[:, :3]
    palette_int = (palette_float * 255).astype(np.uint8)
    hex_by_label = {
        label: f"#{r:02x}{g:02x}{b:02x}"
        for label, (r, g, b) in zip(labels, palette_int.tolist())
    }
    return processor, model, labels, palette_float, palette_int, hex_by_label


# Try the local copy first, then the Hub. Loading must never abort the import
# of this module, so the last resort is a disabled segmenter.
try:
    (
        seg_processor,
        seg_model,
        seg_labels,
        COLOR_MAP_float,
        COLOR_MAP_int,
        COLOR_MAP_DICT,
    ) = _init_segmentation(SEG_MODEL_PATH)
    SEG_CLASSES = len(seg_labels)
    COLOR_MAP = COLOR_MAP_int
    print("✅ Segmentation model (SegFormer) loaded from local path.")
except Exception as e:
    try:
        (
            seg_processor,
            seg_model,
            seg_labels,
            COLOR_MAP_float,
            COLOR_MAP_int,
            COLOR_MAP_DICT,
        ) = _init_segmentation(SEG_MODEL_ID)
        SEG_CLASSES = len(seg_labels)
        COLOR_MAP = COLOR_MAP_int
        print("⚠️ Segmentation model loaded from remote Hub (Fallback).")
    except Exception as e_fb:
        # Report the fallback failure as well as the local one; the local
        # error alone hides why the Hub download did not rescue the load.
        print(
            f"❌ Error loading segmentation model: {e_fb} "
            f"(local load failed with: {e})"
        )
        seg_processor = None
        seg_model = None
        seg_labels = []
        COLOR_MAP_DICT = {"Error": "#000000"}
        SEG_CLASSES = 0
        # Keep the palette name defined so later lookups never hit NameError.
        COLOR_MAP = np.zeros((0, 3), dtype=np.uint8)
|
|
|
|
|
|
|
|
|
def seg_predict(input_img: Image.Image) -> Image.Image:
    """Overlay a colour-coded semantic segmentation on an image.

    Args:
        input_img: PIL image to segment. It is converted to RGB before being
            handed to the image processor.

    Returns:
        An RGB image of the same size as the input: the input blended 50/50
        with the per-class colour mask. When no segmentation model is loaded
        the input image is returned unchanged.
    """
    if seg_model is None:
        return input_img

    # Grayscale / RGBA / palette images would otherwise reach the processor
    # with a channel count other than three.
    rgb_img = input_img.convert("RGB")
    width, height = rgb_img.size

    inputs = seg_processor(images=rgb_img, return_tensors="pt")
    with torch.no_grad():
        logits = seg_model(**inputs).logits.cpu()

    # Resize the logits to the input resolution (PIL size is (W, H), torch
    # wants (H, W)) and pick the best class per pixel.
    upsampled = F.interpolate(
        logits,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    # Keep the integer dtype produced by argmax: narrowing to uint8 would wrap
    # class ids >= 256 before the palette lookup.
    class_ids = upsampled.argmax(dim=1)[0].numpy()

    if SEG_CLASSES > 0:
        colored_mask = COLOR_MAP[class_ids % SEG_CLASSES]
    else:
        colored_mask = np.zeros((*class_ids.shape, 3), dtype=np.uint8)

    mask_img = Image.fromarray(colored_mask).convert("RGBA")
    blended = Image.blend(rgb_img.convert("RGBA"), mask_img, alpha=0.5)
    return blended.convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
# Build the YOLO detector and expose its class names keyed by stringified
# class id. Any failure is reported on stdout and leaves detection disabled.
try:
    det_model = YOLO(DET_MODEL_PATH)
    det_labels_dict = det_model.names
    det_labels = {
        str(class_id): class_name
        for class_id, class_name in det_labels_dict.items()
    }
    print("✅ Detection model (YOLOv8n) loaded.")
except Exception as load_error:
    print(f"❌ Error loading detection model: {load_error}")
    det_model, det_labels = None, {}
|
|
|
|
|
|
|
|
|
def det_predict(input_img: Image.Image, conf: float) -> Image.Image:
    """Run object detection and return the image with detections drawn on it.

    Args:
        input_img: PIL image to run detection on.
        conf: Confidence threshold forwarded to ``det_model.predict``.

    Returns:
        An RGB PIL image with the detections of the first result plotted on
        it. When no detection model is loaded the input image is returned
        unchanged.
    """
    if det_model is None:
        return input_img

    results = det_model.predict(
        source=input_img,
        conf=conf,
        iou=0.5,
        verbose=False,
    )

    # Results.plot() yields a BGR array, while Image.fromarray interprets the
    # channels as RGB; reverse the channel axis so red and blue are not
    # swapped. The copy gives PIL a contiguous buffer.
    annotated_bgr = results[0].plot()
    annotated_rgb = np.ascontiguousarray(annotated_bgr[..., ::-1])
    return Image.fromarray(annotated_rgb)
|
|
|
|
|
|
|
|
|
|
|
|
# Label / colour tables exported for consumers of this module.
# Each falls back to a placeholder when the source table is missing OR empty.
# A bare "is the name defined" check never triggers, because the loaders
# define these names (as empty containers) even when loading fails.
ALL_CLS_LABELS = globals().get("cls_labels") or ["Classification labels not available."]
ALL_DET_LABELS = globals().get("det_labels") or {"Error": "Detection labels not available."}
ALL_SEG_COLOR_MAP = globals().get("COLOR_MAP_DICT") or {"Error": "#000000"}
ALL_SEG_LABELS = globals().get("seg_labels") or ["Segmentation labels not available."]