File size: 3,655 Bytes
2bc3168 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | from __future__ import annotations
from dataclasses import dataclass
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from .ijepa_localization import IJepaPatchLocalizer
from .obstacle_dataset import YoloBox
@dataclass
class ObjectContextResult:
    """Outcome of the object/context embedding comparison for one detection.

    Instances are produced by ``analyze_object_contexts``: the two similarity
    values are cosine similarities between image embeddings, and the two text
    fields are the labels derived from those similarities.
    """

    # 1-based position of the detection in the ``yolo_boxes`` list that was analyzed.
    object_index: int
    # Class label copied from the detection box.
    class_name: str
    # Cosine similarity between the object-crop and context-crop embeddings.
    object_context_similarity: float
    # Cosine similarity between the whole-image and context-crop embeddings.
    scene_context_similarity: float
    # Text label produced by ``describe_context_pattern``.
    context_pattern: str
    # "high", "medium" or "low", produced by ``describe_context_strength``.
    context_strength: str
    # (x1, y1, x2, y2) box that was cropped as the object region.
    object_box_xyxy: tuple[int, int, int, int]
    # (x1, y1, x2, y2) enlarged box that was cropped as the context region.
    context_box_xyxy: tuple[int, int, int, int]
def analyze_object_contexts(
    image: Image.Image,
    yolo_boxes: list[YoloBox],
    localizer: IJepaPatchLocalizer,
    margin_ratio: float = 0.75,
) -> list[ObjectContextResult]:
    """Compare every detected object's embedding with its surroundings and the scene.

    For each detection, an object crop and an enlarged context crop (built
    with ``expand_box``) are taken from *image* and embedded with *localizer*.
    Two cosine similarities are recorded for each detection:
    object-vs-context and whole-scene-vs-context. Both are then turned into
    text labels.

    Args:
        image: Image the detections refer to.
        yolo_boxes: Detections to analyze. Each is converted to an
            ``(x1, y1, x2, y2)`` box via ``to_xyxy(*image.size)``.
        localizer: Embedding model used for the full image and every crop.
        margin_ratio: Margin added to each side of the object box for the
            context crop, as a fraction of the box width/height.

    Returns:
        One ``ObjectContextResult`` per detection whose object and context
        crops are both at least 2x2 pixels. ``object_index`` is the 1-based
        position in *yolo_boxes*, so it stays aligned with the input even when
        detections are skipped. The list is empty when *yolo_boxes* is empty.
    """
    if not yolo_boxes:
        return []

    # The whole-image embedding is identical for every detection; embed it once.
    scene_vector = localizer.embed_image(image)
    image_dims = image.size
    analyzed: list[ObjectContextResult] = []

    for position, detection in enumerate(yolo_boxes, start=1):
        obj_xyxy = detection.to_xyxy(*image_dims)
        ctx_xyxy = expand_box(obj_xyxy, image_dims, margin_ratio=margin_ratio)
        obj_patch = image.crop(obj_xyxy)
        ctx_patch = image.crop(ctx_xyxy)

        # Skip degenerate crops: any side shorter than 2 px on either patch.
        if min(*obj_patch.size, *ctx_patch.size) < 2:
            continue

        obj_vector = localizer.embed_image(obj_patch)
        ctx_vector = localizer.embed_image(ctx_patch)
        obj_vs_ctx = vector_similarity(obj_vector, ctx_vector)
        scene_vs_ctx = vector_similarity(scene_vector, ctx_vector)

        analyzed.append(
            ObjectContextResult(
                object_index=position,
                class_name=detection.class_name,
                object_context_similarity=obj_vs_ctx,
                scene_context_similarity=scene_vs_ctx,
                context_pattern=describe_context_pattern(obj_vs_ctx, scene_vs_ctx),
                context_strength=describe_context_strength(obj_vs_ctx, scene_vs_ctx),
                object_box_xyxy=obj_xyxy,
                context_box_xyxy=ctx_xyxy,
            )
        )

    return analyzed
def expand_box(
    box: tuple[int, int, int, int],
    image_size: tuple[int, int],
    margin_ratio: float,
) -> tuple[int, int, int, int]:
    """Grow *box* by a proportional margin on every side, clipped to the image.

    Args:
        box: ``(x1, y1, x2, y2)`` box to enlarge.
        image_size: ``(width, height)`` of the image (PIL's ``Image.size``).
        margin_ratio: Margin added to each side, as a fraction of the box's
            width (left/right) and height (top/bottom).

    Returns:
        The enlarged ``(x1, y1, x2, y2)`` box. Coordinates are truncated to
        ``int`` and clipped to ``[0, width]`` horizontally and
        ``[0, height]`` vertically.
    """
    image_width, image_height = image_size
    x1, y1, x2, y2 = box
    margin_x = (x2 - x1) * margin_ratio
    margin_y = (y2 - y1) * margin_ratio
    # PIL's ``Image.crop`` treats right/lower as exclusive, so the largest
    # valid edge is the image width/height itself. Clamping to ``width - 1``
    # would drop the last column/row. It could also leave the context box
    # smaller than an object box that touches the image border.
    return (
        int(max(0, x1 - margin_x)),
        int(max(0, y1 - margin_y)),
        int(min(image_width, x2 + margin_x)),
        int(min(image_height, y2 + margin_y)),
    )
def vector_similarity(left, right) -> float:
    """Return the cosine similarity of two embedding vectors as a Python float.

    Each vector is wrapped as a one-row matrix because ``cosine_similarity``
    compares rows of 2-D inputs. The single entry of the resulting 1x1
    matrix is returned.
    """
    left_rows, right_rows = [left], [right]
    pairwise = cosine_similarity(left_rows, right_rows)
    return float(pairwise[0][0])
def describe_context_pattern(object_context_similarity: float, scene_context_similarity: float) -> str:
    """Map the two context similarities to a coarse pattern label.

    Rules are checked in order and the first match wins:
    both similarities >= 0.92, then object/context >= 0.9, then
    scene/context >= 0.9. If none match, the fallback label is returned.
    """
    both_strong = object_context_similarity >= 0.92 and scene_context_similarity >= 0.92
    object_strong = object_context_similarity >= 0.9
    scene_strong = scene_context_similarity >= 0.9

    if both_strong:
        label = "near other objects / scene-embedded"
    elif object_strong:
        label = "isolated / object-dominant"
    elif scene_strong:
        label = "group / crowd context"
    else:
        label = "distinct object/context"
    return label
def describe_context_strength(object_context_similarity: float, scene_context_similarity: float) -> str:
    """Grade the larger of the two context similarities.

    Returns "high" at 0.93 or above, "medium" at 0.86 or above, and "low"
    otherwise.
    """
    peak = max(object_context_similarity, scene_context_similarity)
    # Tiers are ordered from strictest to loosest; the first one reached wins.
    for floor, grade in ((0.93, "high"), (0.86, "medium")):
        if peak >= floor:
            return grade
    return "low"
|