File size: 3,655 Bytes
2bc3168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations

from dataclasses import dataclass

from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox


@dataclass
class ObjectContextResult:
    """Similarity summary for one detected object and its surrounding context.

    Instances are produced by ``analyze_object_contexts``, one per detection
    whose object and context crops are both at least 2 pixels on each side.
    """

    # 1-based position of the detection in the input box list.
    object_index: int
    # Class label copied from the detection's ``class_name`` attribute.
    class_name: str
    # Cosine similarity between the object-crop and context-crop embeddings.
    object_context_similarity: float
    # Cosine similarity between the whole-image and context-crop embeddings.
    scene_context_similarity: float
    # Human-readable label returned by ``describe_context_pattern``.
    context_pattern: str
    # "high", "medium" or "low", as returned by ``describe_context_strength``.
    context_strength: str
    # Pixel box of the object as (x1, y1, x2, y2).
    object_box_xyxy: tuple[int, int, int, int]
    # Pixel box of the enlarged context region as (x1, y1, x2, y2).
    context_box_xyxy: tuple[int, int, int, int]


def analyze_object_contexts(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    margin_ratio: float = 0.75,
) -> list[ObjectContextResult]:
    """Compare every detected object with its enlarged surroundings and the scene.

    For each box, the object crop and a margin-expanded context crop are
    embedded with ``localizer.embed_image``; the context embedding is then
    compared (cosine similarity) against both the object embedding and the
    embedding of the full image.

    Args:
        image: Image the boxes refer to.
        yolo_boxes: Detections; each must provide ``to_xyxy(width, height)``
            and a ``class_name`` attribute.
        localizer: Object exposing ``embed_image``; its output is presumably a
            1-D feature vector (TODO confirm against the localizer).
        margin_ratio: Fraction of the box size added on each side to build the
            context region.

    Returns:
        One ``ObjectContextResult`` per box whose object and context crops are
        both at least 2 pixels wide and tall. ``object_index`` is the 1-based
        position in ``yolo_boxes``, so skipped boxes leave gaps in the
        numbering. An empty input yields an empty list.
    """
    if not yolo_boxes:
        # Nothing to analyse, so the whole-scene embedding is not computed.
        return []

    whole_scene_vector = localizer.embed_image(image)
    analyses = []
    for position, detection in enumerate(yolo_boxes, start=1):
        tight_box = detection.to_xyxy(*image.size)
        wide_box = expand_box(tight_box, image.size, margin_ratio=margin_ratio)
        tight_crop = image.crop(tight_box)
        wide_crop = image.crop(wide_box)

        # Skip degenerate crops: any side shorter than 2 pixels.
        smallest_side = min(
            tight_crop.width,
            tight_crop.height,
            wide_crop.width,
            wide_crop.height,
        )
        if smallest_side < 2:
            continue

        tight_vector = localizer.embed_image(tight_crop)
        wide_vector = localizer.embed_image(wide_crop)

        object_score = vector_similarity(tight_vector, wide_vector)
        scene_score = vector_similarity(whole_scene_vector, wide_vector)
        pattern_label = describe_context_pattern(object_score, scene_score)
        strength_label = describe_context_strength(object_score, scene_score)

        analysis = ObjectContextResult(
            object_index=position,
            class_name=detection.class_name,
            object_context_similarity=object_score,
            scene_context_similarity=scene_score,
            context_pattern=pattern_label,
            context_strength=strength_label,
            object_box_xyxy=tight_box,
            context_box_xyxy=wide_box,
        )
        analyses.append(analysis)
    return analyses


def expand_box(
    box: tuple[int, int, int, int],
    image_size: tuple[int, int],
    margin_ratio: float,
) -> tuple[int, int, int, int]:
    """Grow ``box`` by a proportional margin on every side, clamped to the image.

    Args:
        box: ``(x1, y1, x2, y2)`` pixel box to enlarge.
        image_size: ``(width, height)`` of the image the box belongs to.
        margin_ratio: Fraction of the box's own width (horizontally) and
            height (vertically) added on each side.

    Returns:
        The enlarged ``(x1, y1, x2, y2)`` box with integer coordinates. The
        left/upper edges are clamped to 0 and the right/lower edges to the
        image width/height.
    """
    image_width, image_height = image_size
    x1, y1, x2, y2 = box
    margin_x = (x2 - x1) * margin_ratio
    margin_y = (y2 - y1) * margin_ratio
    # The result is handed to ``Image.crop``, whose right/lower coordinates are
    # exclusive: a box ending at ``image_width`` covers the last pixel column.
    # Clamping to ``image_width - 1`` would silently drop that column/row and
    # could make the context narrower than an object touching the border.
    return (
        int(max(0, x1 - margin_x)),
        int(max(0, y1 - margin_y)),
        int(min(image_width, x2 + margin_x)),
        int(min(image_height, y2 + margin_y)),
    )


def vector_similarity(left, right) -> float:
    """Return the cosine similarity of two vectors as a plain ``float``.

    Each vector is wrapped in a list so ``cosine_similarity`` receives two
    single-row matrices; the only entry of the resulting matrix is returned.
    """
    pairwise_scores = cosine_similarity([left], [right])
    return float(pairwise_scores[0, 0])


def describe_context_pattern(object_context_similarity: float, scene_context_similarity: float) -> str:
    """Map the two similarity scores to a descriptive context label.

    Rules are checked in order; the first match wins:

    1. both scores >= 0.92 -> "near other objects / scene-embedded"
    2. object score >= 0.9 -> "isolated / object-dominant"
    3. scene score >= 0.9  -> "group / crowd context"
    4. otherwise           -> "distinct object/context"
    """
    object_is_very_close = object_context_similarity >= 0.92
    scene_is_very_close = scene_context_similarity >= 0.92

    if object_is_very_close and scene_is_very_close:
        label = "near other objects / scene-embedded"
    elif object_context_similarity >= 0.9:
        label = "isolated / object-dominant"
    elif scene_context_similarity >= 0.9:
        label = "group / crowd context"
    else:
        label = "distinct object/context"
    return label


def describe_context_strength(object_context_similarity: float, scene_context_similarity: float) -> str:
    """Grade the stronger of the two similarity scores.

    Returns "high" when the larger score is at least 0.93, "medium" when it is
    at least 0.86, and "low" otherwise.
    """
    peak = max(object_context_similarity, scene_context_similarity)
    # Thresholds are listed from strictest to loosest; the first one met wins.
    graded_floors = ((0.93, "high"), (0.86, "medium"))
    for floor, grade in graded_floors:
        if peak >= floor:
            return grade
    return "low"