Spaces:
Runtime error
Runtime error
File size: 10,748 Bytes
d686824 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
import os
import cv2
import os.path as osp
from mmengine.config import Config
from mmengine.dataset import Compose
from mmdet.apis import init_detector
from mmdet.utils import get_test_pipeline_cfg
# from mmengine.runner.amp import autocast
from torch.amp import autocast
import torch
import supervision as sv
from typing import Dict, Optional, Sequence, List
import supervision as sv
class LabelAnnotator(sv.LabelAnnotator):
    """Label annotator whose text background hangs right/down from the anchor point.

    Overrides ``sv.LabelAnnotator.resolve_text_background_xyxy`` so the
    requested ``position`` is ignored: the anchor coordinates always become the
    top-left corner of the label box.
    """

    @staticmethod
    def resolve_text_background_xyxy(
        center_coordinates,
        text_wh,
        position,
    ):
        """Return the ``(x1, y1, x2, y2)`` box for a label background.

        Args:
            center_coordinates: ``(x, y)`` anchor; used as the box's top-left corner.
            text_wh: ``(width, height)`` of the label box as supplied by the caller.
            position: Accepted for signature compatibility with the parent; unused.

        Returns:
            Tuple ``(x, y, x + width, y + height)``.
        """
        anchor_x, anchor_y = center_coordinates
        box_w, box_h = text_wh
        far_x, far_y = anchor_x + box_w, anchor_y + box_h
        return anchor_x, anchor_y, far_x, far_y
class YoloInterface:
    """Common base for the detector wrappers in this module.

    Provides the shared annotator setup; subclasses load their own model.
    """

    def __init__(self):
        """Create an empty interface; subclasses perform all model setup."""

    def set_BBoxAnnotator(self):
        """Create the annotators used to draw detections.

        Sets ``self.BOUNDING_BOX_ANNOTATOR`` (1 px box outlines) and
        ``self.LABEL_ANNOTATOR`` (this module's ``LabelAnnotator`` with
        padding 4, text scale 0.5 and text thickness 1).
        """
        self.BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1)
        label_style = dict(text_padding=4, text_scale=0.5, text_thickness=1)
        self.LABEL_ANNOTATOR = LabelAnnotator(**label_style)
class YoloWorldInterface(YoloInterface):
    """Open-vocabulary detector wrapper around an mmdet YOLO-World model.

    Call :meth:`reparameterize_object_list` to set the text prompts before
    running :meth:`inference` or :meth:`inference_detector`; class ids in the
    returned detections index into ``self.texts``.
    """

    def __init__(self, config_path: str, checkpoint_path: str, device: str = "cuda:0"):
        """
        Load the YOLO-World model from an mmengine config and a checkpoint.

        Args:
            config_path (str): Path to the model configuration file.
            checkpoint_path (str): Path to the model checkpoint.
            device (str): Device to run the model on (e.g., 'cuda:0', 'cpu').
        """
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.device = device
        # Load configuration
        cfg = Config.fromfile(config_path)
        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(config_path))[0])
        cfg.load_from = checkpoint_path
        # Initialize the model
        self.model = init_detector(cfg, checkpoint=checkpoint_path, device=device)
        self.set_BBoxAnnotator()
        # Swap the first pipeline step for mmdet.LoadImageFromNDArray so the
        # pipeline consumes in-memory images passed as data_info['img'].
        self.model.cfg.test_dataloader.dataset.pipeline[
            0].type = 'mmdet.LoadImageFromNDArray'
        self.test_pipeline = Compose(self.model.cfg.test_dataloader.dataset.pipeline)

    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """
        Set the text prompts to detect and re-parameterize the model for them.

        ``self.texts`` becomes one single-element list per object (target
        objects first, then cue objects, whitespace-stripped) followed by a
        trailing ``[' ']`` prompt.

        Args:
            target_objects (List[str]): List of target object names.
            cue_objects (List[str]): List of cue object names.
        """
        combined_texts = target_objects + cue_objects
        self.texts = [[obj.strip()] for obj in combined_texts] + [[' ']]
        self.model.reparameterize(self.texts)

    def _autocast(self, use_amp: bool):
        """Return a mixed-precision context for this model's device type.

        ``torch.amp.autocast`` requires ``device_type``; calling it with only
        ``enabled=`` raises ``TypeError``.
        """
        return autocast(device_type=torch.device(self.device).type, enabled=use_amp)

    @staticmethod
    def _filter_instances(pred_instances, score_threshold: float, max_dets: int):
        """Keep predictions scoring strictly above ``score_threshold``, then at most the top ``max_dets``."""
        pred_instances = pred_instances[pred_instances.scores.float() > score_threshold]
        if len(pred_instances.scores) > max_dets:
            indices = pred_instances.scores.float().topk(max_dets)[1]
            pred_instances = pred_instances[indices]
        return pred_instances

    @staticmethod
    def _to_detections(pred_instances):
        """Convert prediction instances to ``sv.Detections``.

        Everything (masks included) is moved to CPU numpy first, because
        ``sv.Detections`` expects numpy arrays.
        """
        pred_instances = pred_instances.cpu().numpy()
        return sv.Detections(
            xyxy=pred_instances['bboxes'],
            class_id=pred_instances['labels'],
            confidence=pred_instances['scores'],
            mask=pred_instances.get('masks', None),
        )

    def inference(self, image, max_dets: int = 100, score_threshold: float = 0.3, use_amp: bool = False):
        """
        Run inference on a single image.

        Args:
            image (str | np.ndarray): Path to an image file, or an already
                loaded image array. Paths are read with ``cv2.imread`` (BGR
                channel order) because the test pipeline was switched to
                ``LoadImageFromNDArray`` in ``__init__``, which reads ``img``
                rather than ``img_path``.
            max_dets (int): Maximum number of detections to keep.
            score_threshold (float): Detections must score strictly above this.
            use_amp (bool): Whether to use mixed precision for inference.

        Returns:
            sv.Detections: Detection results; class ids index into ``self.texts``.

        Raises:
            FileNotFoundError: If ``image`` is a path that cannot be read.
        """
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise FileNotFoundError(f"Could not read image: {image}")
        else:
            img = image
        data_info = dict(img_id=0, img=img, texts=self.texts)
        data_info = self.test_pipeline(data_info)
        data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                          data_samples=[data_info['data_samples']])
        with self._autocast(use_amp), torch.no_grad():
            output = self.model.test_step(data_batch)[0]
        pred_instances = self._filter_instances(output.pred_instances, score_threshold, max_dets)
        return self._to_detections(pred_instances)

    def inference_detector(self, images, max_dets=50, score_threshold=0.2, use_amp: bool = False):
        """
        Run inference and return a list of ``sv.Detections`` (one per model output).

        Only ``images[0]`` is fed to the model for now (batching is TBD). The
        raw outputs, with ``pred_instances`` replaced by the filtered
        predictions, are kept on ``self.detect_outputs_raw``. The returned list
        is also stored on ``self.detections_inbatch``.

        Args:
            images: Sequence of in-memory images; presumably ndarrays in the
                channel order the model expects (TODO confirm with callers).
            max_dets (int): Maximum number of detections kept per output.
            score_threshold (float): Detections must score strictly above this.
            use_amp (bool): Whether to use mixed precision for inference.

        Returns:
            List[sv.Detections]: Class ids index into ``self.texts``.
        """
        data_info = dict(img_id=0, img=images[0], texts=self.texts)  # TODO: batch searching
        data_info = self.test_pipeline(data_info)
        data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                          data_samples=[data_info['data_samples']])
        with self._autocast(use_amp), torch.no_grad():
            outputs = self.model.test_step(data_batch)
        detections_inbatch = []
        for output in outputs:
            pred_instances = self._filter_instances(output.pred_instances, score_threshold, max_dets)
            # Keep the filtered (pre-numpy) predictions on the raw output.
            output.pred_instances = pred_instances
            detections_inbatch.append(self._to_detections(pred_instances))
        self.detect_outputs_raw = outputs
        self.detections_inbatch = detections_inbatch
        return detections_inbatch

    def bbox_visualization(self, images, detections_inbatch):
        """Draw boxes and "<prompt> <score>" labels on copies of the input images.

        Args:
            images: Images aligned with ``detections_inbatch``; ``images[b]`` is
                annotated with ``detections_inbatch[b]``. Inputs are not modified.
            detections_inbatch: Sequence of ``sv.Detections`` whose class ids
                index into ``self.texts``.

        Returns:
            List of annotated image copies, one per entry of ``detections_inbatch``.
        """
        anno_images = []
        for b, detections in enumerate(detections_inbatch):
            labels = [
                f"{self.texts[class_id][0]} {confidence:0.2f}"
                for class_id, confidence in zip(detections.class_id, detections.confidence)
            ]
            # Annotate a copy of the matching image so the caller's array is untouched.
            anno_image = images[b].copy()
            anno_image = self.BOUNDING_BOX_ANNOTATOR.annotate(anno_image, detections)
            anno_image = self.LABEL_ANNOTATOR.annotate(anno_image, detections, labels=labels)
            anno_images.append(anno_image)
        return anno_images
import torch
from typing import List
import supervision as sv # 确保已安装 Supervision 库
import os.path as osp
class YoloV5Interface(YoloInterface):
    """Closed-vocabulary detector wrapper around a YOLOv5 model loaded via ``torch.hub``.

    Class ids in the returned detections are the model's own class indices
    (the ids of ``self.model.names``), optionally restricted to the names set
    by :meth:`reparameterize_object_list`.
    """

    def __init__(self, config_path="ultralytics/yolov5", checkpoint_path: str = 'yolov5s', device: str = 'cuda:0'):
        """
        Load a YOLOv5 model through ``torch.hub``.

        Args:
            config_path (str): torch.hub repository to load from
                (default ``"ultralytics/yolov5"``).
            checkpoint_path (str): Model entry point in that repository, e.g.
                ``'yolov5s'``, ``'yolov5m'``, ``'yolov5l'``, ``'yolov5x'``.
            device (str): Device to run the model on (e.g. ``'cuda:0'``, ``'cpu'``).
        """
        self.device = device
        # Honour the arguments rather than a hard-coded repo/model pair.
        self.model = torch.hub.load(config_path, checkpoint_path, pretrained=True)
        self.model.to(self.device)
        self.model.eval()
        self.target_classes = None  # class names to keep; None keeps every class
        self.texts = None
        self.test_pipeline = None

    def reparameterize_object_list(self, target_objects: List[str], cue_objects: List[str]):
        """
        Restrict subsequent detections to the given object names.

        Args:
            target_objects (List[str]): Target object names.
            cue_objects (List[str]): Cue object names.
        """
        self.target_classes = target_objects + cue_objects

    def inference(self, images: str, max_dets: int = 100, score_threshold: float = 0.3, use_amp: bool = False,
                  size: int = 640):
        """
        Run YOLOv5 and return filtered detections for each image.

        Args:
            images: Passed unchanged to ``self.model``; presumably a path,
                array, or list of them as accepted by YOLOv5's hub model
                (TODO confirm with callers).
            max_dets (int): Maximum detections kept per image, applied after the
                score and class filters.
            score_threshold (float): Detections must have confidence strictly
                above this.
            use_amp (bool): Currently unused; kept for interface compatibility.
            size (int): Inference size forwarded to the model (default 640).

        Returns:
            List[sv.Detections]: One entry per tensor in ``results.pred``.
        """
        results = self.model(images, size=size)

        target_ids = None
        if self.target_classes is not None:
            class_names = self.model.names
            # ``names`` may be a dict {id: name} or a list depending on the YOLOv5 version.
            name_items = class_names.items() if isinstance(class_names, dict) else enumerate(class_names)
            wanted = set(self.target_classes)
            target_ids = {int(i) for i, name in name_items if name in wanted}

        filtered_detections = []
        for det in results.pred:  # each row: x1, y1, x2, y2, confidence, class
            det = det[det[:, 4] > score_threshold]
            # Filter by class BEFORE truncating, so non-target boxes cannot
            # crowd out target ones.
            if target_ids is not None:
                keep = torch.tensor(
                    [int(c) in target_ids for c in det[:, 5].tolist()],
                    dtype=torch.bool,
                    device=det.device,
                )
                det = det[keep]
            if len(det) > max_dets:
                det = det[det[:, 4].topk(max_dets)[1]]
            filtered_detections.append(
                sv.Detections(
                    xyxy=det[:, :4].cpu().numpy(),
                    confidence=det[:, 4].cpu().numpy(),
                    class_id=det[:, 5].cpu().numpy().astype(int),
                )
            )
        return filtered_detections
|