| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| from PIL import Image |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
| from .ijepa_localization import IJepaPatchLocalizer |
| from .obstacle_dataset import YoloBox |
|
|
|
|
| @dataclass |
| class ObjectContextResult: |
| object_index: int |
| class_name: str |
| object_context_similarity: float |
| scene_context_similarity: float |
| context_pattern: str |
| context_strength: str |
| object_box_xyxy: tuple[int, int, int, int] |
| context_box_xyxy: tuple[int, int, int, int] |
|
|
|
|
| def analyze_object_contexts( |
| image: Image.Image, |
| yolo_boxes: list[YoloBox], |
| localizer: IJepaPatchLocalizer, |
| margin_ratio: float = 0.75, |
| ) -> list[ObjectContextResult]: |
| if not yolo_boxes: |
| return [] |
|
|
| scene_embedding = localizer.embed_image(image) |
| results = [] |
| for index, box in enumerate(yolo_boxes, start=1): |
| object_box = box.to_xyxy(*image.size) |
| context_box = expand_box(object_box, image.size, margin_ratio=margin_ratio) |
| object_crop = image.crop(object_box) |
| context_crop = image.crop(context_box) |
| if object_crop.width < 2 or object_crop.height < 2 or context_crop.width < 2 or context_crop.height < 2: |
| continue |
| object_embedding = localizer.embed_image(object_crop) |
| context_embedding = localizer.embed_image(context_crop) |
|
|
| object_context_similarity = vector_similarity(object_embedding, context_embedding) |
| scene_context_similarity = vector_similarity(scene_embedding, context_embedding) |
| results.append( |
| ObjectContextResult( |
| object_index=index, |
| class_name=box.class_name, |
| object_context_similarity=object_context_similarity, |
| scene_context_similarity=scene_context_similarity, |
| context_pattern=describe_context_pattern( |
| object_context_similarity, |
| scene_context_similarity, |
| ), |
| context_strength=describe_context_strength( |
| object_context_similarity, |
| scene_context_similarity, |
| ), |
| object_box_xyxy=object_box, |
| context_box_xyxy=context_box, |
| ) |
| ) |
| return results |
|
|
|
|
| def expand_box( |
| box: tuple[int, int, int, int], |
| image_size: tuple[int, int], |
| margin_ratio: float, |
| ) -> tuple[int, int, int, int]: |
| image_width, image_height = image_size |
| x1, y1, x2, y2 = box |
| width = x2 - x1 |
| height = y2 - y1 |
| margin_x = width * margin_ratio |
| margin_y = height * margin_ratio |
| return ( |
| int(max(0, x1 - margin_x)), |
| int(max(0, y1 - margin_y)), |
| int(min(image_width - 1, x2 + margin_x)), |
| int(min(image_height - 1, y2 + margin_y)), |
| ) |
|
|
|
|
| def vector_similarity(left, right) -> float: |
| return float(cosine_similarity([left], [right])[0, 0]) |
|
|
|
|
| def describe_context_pattern(object_context_similarity: float, scene_context_similarity: float) -> str: |
| if object_context_similarity >= 0.92 and scene_context_similarity >= 0.92: |
| return "near other objects / scene-embedded" |
| if object_context_similarity >= 0.9: |
| return "isolated / object-dominant" |
| if scene_context_similarity >= 0.9: |
| return "group / crowd context" |
| return "distinct object/context" |
|
|
|
|
| def describe_context_strength(object_context_similarity: float, scene_context_similarity: float) -> str: |
| strength = max(object_context_similarity, scene_context_similarity) |
| if strength >= 0.93: |
| return "high" |
| if strength >= 0.86: |
| return "medium" |
| return "low" |
|
|