"""Tests for the vision analysis endpoints."""
|
|
| from __future__ import annotations |
|
|
| import base64 |
| import io |
| from unittest.mock import AsyncMock, patch |
|
|
| import pytest |
| from PIL import Image |
|
|
| from maris_core.memory_context import ConversationMemoryStore |
| from maris_core.vision.analyze import ( |
| _LIVE_CAMERAS, |
| BoundingBox, |
| FrameAnalysis, |
| FrameSequenceRequest, |
| ImageSourceRequest, |
| LiveCameraConfigRequest, |
| LiveCameraConnectRequest, |
| LiveFrameRequest, |
| LiveSessionCommandRequest, |
| OCRTextBlock, |
| VisionDetection, |
| analyze_frames, |
| analyze_image, |
| configure_live_camera, |
| connect_live_camera, |
| estimate_pose, |
| list_live_cameras, |
| ocr_image, |
| process_live_frame, |
| scene_timeline, |
| start_live_camera, |
| track_objects, |
| ) |
|
|
|
|
def _sample_image_base64(
    size: tuple[int, int] = (16, 12),
    color: tuple[int, int, int] = (220, 80, 80),
) -> str:
    """Return a base64-encoded PNG of a solid-colour RGB image.

    Args:
        size: Image size handed to ``Image.new``.
        color: RGB fill colour handed to ``Image.new``.

    Returns:
        The PNG bytes encoded as a base64 ``str``.
    """
    with io.BytesIO() as png_stream:
        Image.new("RGB", size, color).save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode()
|
|
|
|
@pytest.mark.asyncio
async def test_analyze_image_fallback_without_detector() -> None:
    """Check ``analyze_image`` when detection reports the fallback path.

    ``_detect_image_payload`` is patched to return no detections with the
    fallback flag set, and ``HFIntegration.save_generation`` is replaced by
    an ``AsyncMock``.
    """
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    detect_patch = patch(
        "maris_core.vision.analyze._detect_image_payload",
        return_value=([], "fallback/basic-image-summary", True),
    )
    with save_patch, detect_patch:
        request = ImageSourceRequest(image_base64=_sample_image_base64())
        response = await analyze_image(request)

    assert response.fallback_used is True
    # The default sample image is 16x12.
    assert response.width == 16
    assert response.height == 12
    assert response.detections == []
    assert "Fallback vision summary" in response.summary
|
|
|
|
@pytest.mark.asyncio
async def test_analyze_image_maps_detector_output() -> None:
    """Check that patched detector output is reflected in the response.

    ``_detect_image_payload`` is patched to return two detections together
    with a model name and a cleared fallback flag; the response is expected
    to carry the same labels, model and flag.
    """
    person = VisionDetection(
        label="person",
        confidence=0.93,
        bbox=BoundingBox(x=1, y=2, width=10, height=12),
    )
    phone = VisionDetection(
        label="cell phone",
        confidence=0.42,
        bbox=BoundingBox(x=4, y=5, width=4, height=5),
    )
    detections = [person, phone]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    detect_patch = patch(
        "maris_core.vision.analyze._detect_image_payload",
        return_value=(detections, "facebook/detr-resnet-50", False),
    )
    with save_patch, detect_patch:
        request = ImageSourceRequest(image_base64=_sample_image_base64(size=(20, 18)))
        response = await analyze_image(request)

    assert response.fallback_used is False
    assert response.model == "facebook/detr-resnet-50"
    returned_labels = [item.label for item in response.detections]
    assert returned_labels == ["person", "cell phone"]
    assert "person" in response.summary
|
|
|
|
@pytest.mark.asyncio
async def test_analyze_image_persists_summary_into_shared_session_memory() -> None:
    """Check that the response summary can be retrieved from session memory.

    The module-level ``memory_store`` in ``maris_core.vision.analyze`` is
    swapped for a fresh ``ConversationMemoryStore`` so the test can query it
    afterwards under the same session id used in the request.
    """
    memory = ConversationMemoryStore()

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    detect_patch = patch(
        "maris_core.vision.analyze._detect_image_payload",
        return_value=([], "fallback/basic-image-summary", True),
    )
    memory_patch = patch("maris_core.vision.analyze.memory_store", memory)
    with save_patch, detect_patch, memory_patch:
        request = ImageSourceRequest(
            image_base64=_sample_image_base64(),
            session_id="vision-session",
            camera_id="cam-1",
        )
        response = await analyze_image(request)

    matches = memory.retrieve_relevant_context("vision-session", "Fallback vision summary")
    assert matches
    top_match = matches[0]
    assert response.summary == top_match.content
|
|
|
|
@pytest.mark.asyncio
async def test_ocr_image_returns_text_blocks() -> None:
    """Check that ``ocr_image`` returns the blocks from the patched extractor.

    ``_extract_ocr_blocks`` is patched to return a single text block, the
    model name ``"pytesseract"`` and a cleared fallback flag.
    """
    text_block = OCRTextBlock(
        text="MARIS AI",
        confidence=0.88,
        bbox=BoundingBox(x=2, y=3, width=40, height=12),
        language="lv",
    )
    blocks = [text_block]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    ocr_patch = patch(
        "maris_core.vision.analyze._extract_ocr_blocks",
        return_value=(blocks, "pytesseract", False),
    )
    with save_patch, ocr_patch:
        request = ImageSourceRequest(image_base64=_sample_image_base64(size=(64, 32)))
        response = await ocr_image(request)

    assert response.model == "pytesseract"
    assert response.fallback_used is False
    first_result = response.results[0]
    assert first_result.text == "MARIS AI"
    assert "OCR pabeigts" in response.summary
|
|
|
|
@pytest.mark.asyncio
async def test_pose_estimate_derives_keypoints_from_person_detections() -> None:
    """Check that ``estimate_pose`` yields one pose for one person detection.

    ``_detect_image_payload`` is patched to return a single ``person``
    detection; the response is expected to contain one pose with a ``nose``
    keypoint and at least one connection starting at ``nose``.
    """
    person = VisionDetection(
        label="person",
        confidence=0.91,
        bbox=BoundingBox(x=10, y=20, width=40, height=120),
    )
    detections = [person]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    detect_patch = patch(
        "maris_core.vision.analyze._detect_image_payload",
        return_value=(detections, "facebook/detr-resnet-50", False),
    )
    with save_patch, detect_patch:
        request = ImageSourceRequest(image_base64=_sample_image_base64(size=(100, 160)))
        response = await estimate_pose(request)

    assert response.model == "bbox-derived-pose-v1"
    assert response.fallback_used is False
    assert len(response.poses) == 1
    # Only index into the pose list after its length has been checked.
    pose = response.poses[0]
    assert any(point.name == "nose" for point in pose.keypoints)
    assert any(connection.start == "nose" for connection in pose.connections)
|
|
|
|
@pytest.mark.asyncio
async def test_tracking_builds_tracks_from_frame_analysis() -> None:
    """Check that ``track_objects`` links one person across two frames.

    ``_load_frames`` and ``_build_frame_analysis`` are patched so the
    endpoint receives two prepared frames, each with a single ``person``
    detection at slightly shifted coordinates. The response is expected to
    contain one track with two observations.
    """
    frames = [
        Image.new("RGB", (80, 60), shade)
        for shade in ((20, 20, 20), (30, 30, 30))
    ]

    person_in_first = VisionDetection(
        label="person",
        confidence=0.9,
        bbox=BoundingBox(x=10, y=10, width=20, height=40),
    )
    person_in_second = VisionDetection(
        label="person",
        confidence=0.87,
        bbox=BoundingBox(x=14, y=11, width=20, height=40),
    )
    first_analysis = FrameAnalysis(
        frame_index=0,
        summary="frame 0",
        detections=[person_in_first],
        dominant_labels=["person"],
        brightness=90.0,
    )
    second_analysis = FrameAnalysis(
        frame_index=1,
        summary="frame 1",
        detections=[person_in_second],
        dominant_labels=["person"],
        brightness=94.0,
    )
    analyses = [first_analysis, second_analysis]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    load_patch = patch(
        "maris_core.vision.analyze._load_frames",
        new=AsyncMock(return_value=frames),
    )
    analysis_patch = patch(
        "maris_core.vision.analyze._build_frame_analysis",
        return_value=(analyses, "facebook/detr-resnet-50", False),
    )
    with save_patch, load_patch, analysis_patch:
        encoded_frames = [_sample_image_base64() for _ in range(2)]
        request = FrameSequenceRequest(frames_base64=encoded_frames)
        response = await track_objects(request)

    assert response.frame_count == 2
    assert len(response.tracks) == 1
    track = response.tracks[0]
    assert track.track_id == 1
    assert len(track.observations) == 2
|
|
|
|
@pytest.mark.asyncio
async def test_scene_timeline_groups_frames_into_scenes() -> None:
    """Check that ``scene_timeline`` splits three frames into two scenes.

    ``_load_frames`` and ``_build_frame_analysis`` are patched. The first
    two prepared analyses share the ``person`` label and similar brightness;
    the third has a different label and brightness. The response is expected
    to start scenes at frame 0 and frame 2.
    """
    frames = []
    for _ in range(3):
        frames.append(Image.new("RGB", (80, 60), (20, 20, 20)))

    scene_a_start = FrameAnalysis(
        frame_index=0,
        summary="scene a",
        detections=[],
        dominant_labels=["person"],
        brightness=80.0,
    )
    scene_a_next = FrameAnalysis(
        frame_index=1,
        summary="scene a2",
        detections=[],
        dominant_labels=["person", "chair"],
        brightness=84.0,
    )
    scene_b_start = FrameAnalysis(
        frame_index=2,
        summary="scene b",
        detections=[],
        dominant_labels=["car"],
        brightness=150.0,
    )
    analyses = [scene_a_start, scene_a_next, scene_b_start]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    load_patch = patch(
        "maris_core.vision.analyze._load_frames",
        new=AsyncMock(return_value=frames),
    )
    analysis_patch = patch(
        "maris_core.vision.analyze._build_frame_analysis",
        return_value=(analyses, "facebook/detr-resnet-50", False),
    )
    with save_patch, load_patch, analysis_patch:
        encoded_frames = [_sample_image_base64() for _ in range(3)]
        request = FrameSequenceRequest(frames_base64=encoded_frames)
        response = await scene_timeline(request)

    assert response.frame_count == 3
    assert len(response.scenes) == 2
    opening_scene, closing_scene = response.scenes[0], response.scenes[1]
    assert opening_scene.start_frame == 0
    assert closing_scene.start_frame == 2
    assert "Scene timeline pabeigta" in response.summary
|
|
|
|
@pytest.mark.asyncio
async def test_frame_analysis_returns_per_frame_summaries() -> None:
    """Check that ``analyze_frames`` returns the patched per-frame analysis.

    ``_load_frames`` and ``_build_frame_analysis`` are patched to supply one
    frame and one analysis with the fallback flag set.
    """
    frames = [Image.new("RGB", (80, 60), (20, 20, 20))]
    only_analysis = FrameAnalysis(
        frame_index=0,
        summary="frame one",
        detections=[],
        dominant_labels=[],
        brightness=42.0,
    )
    analyses = [only_analysis]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    load_patch = patch(
        "maris_core.vision.analyze._load_frames",
        new=AsyncMock(return_value=frames),
    )
    analysis_patch = patch(
        "maris_core.vision.analyze._build_frame_analysis",
        return_value=(analyses, "fallback/basic-image-summary", True),
    )
    with save_patch, load_patch, analysis_patch:
        request = FrameSequenceRequest(frames_base64=[_sample_image_base64()])
        response = await analyze_frames(request)

    assert response.frame_count == 1
    first_frame = response.frames[0]
    assert first_frame.summary == "frame one"
    assert response.fallback_used is True
|
|
|
|
@pytest.mark.asyncio
async def test_live_camera_connect_and_list_registry() -> None:
    """Check that a connected camera is reported by ``list_live_cameras``.

    The ``_LIVE_CAMERAS`` registry is cleared first so the connected camera
    is the only entry expected in the listing.
    """
    _LIVE_CAMERAS.clear()

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    with save_patch:
        connect_request = LiveCameraConnectRequest(
            camera_id="cam-browser",
            source_type="browser_camera",
            transport="getUserMedia",
            device_id="device-1",
        )
        response = await connect_live_camera(connect_request)
        registry = await list_live_cameras()

    connected = response.camera
    assert connected.camera_id == "cam-browser"
    assert connected.status == "connected"

    listed = registry.cameras[0]
    assert listed.camera_id == "cam-browser"
    assert listed.health.connected is True
|
|
|
|
@pytest.mark.asyncio
async def test_live_frame_processing_updates_tracks_timeline_and_events() -> None:
    """Check camera state after connect, start and one processed frame.

    ``_frame_detections`` is patched to return one ``person`` detection and
    ``_extract_ocr_blocks`` to return no text. The ``_LIVE_CAMERAS`` registry
    is cleared before the camera is connected.
    """
    _LIVE_CAMERAS.clear()
    sample = _sample_image_base64(size=(40, 30))
    person = VisionDetection(
        label="person",
        confidence=0.91,
        bbox=BoundingBox(x=5, y=4, width=18, height=22),
    )
    detections = [person]

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    detect_patch = patch(
        "maris_core.vision.analyze._frame_detections",
        return_value=(detections, "facebook/detr-resnet-50", False),
    )
    ocr_patch = patch(
        "maris_core.vision.analyze._extract_ocr_blocks",
        return_value=([], "pytesseract", False),
    )
    with save_patch, detect_patch, ocr_patch:
        connect_request = LiveCameraConnectRequest(
            camera_id="cam-live",
            source_type="browser_camera",
            transport="getUserMedia",
            device_id="device-1",
        )
        await connect_live_camera(connect_request)

        start_request = LiveSessionCommandRequest(camera_id="cam-live")
        await start_live_camera(start_request)

        frame_request = LiveFrameRequest(
            camera_id="cam-live",
            image_base64=sample,
            frame_index=0,
        )
        response = await process_live_frame(frame_request)

    assert response.camera.status == "streaming"
    assert response.camera.health.analysis_active is True
    assert response.camera.latest_result["detections"][0]["label"] == "person"
    assert response.camera.tracks[0].label == "person"
    assert response.events[0].type == "analysis_result"
|
|
|
|
@pytest.mark.asyncio
async def test_live_camera_config_updates_roi_and_rules() -> None:
    """Check that ``configure_live_camera`` echoes the submitted settings.

    A camera is connected after clearing ``_LIVE_CAMERAS``, then configured
    with one ROI zone, one alert rule and an FPS budget; the returned camera
    is expected to carry the same values.
    """
    _LIVE_CAMERAS.clear()
    gate_zone = {"label": "gate", "x": 10, "y": 20, "width": 30, "height": 40}
    person_rule = "person_zone:person:1:0.7"

    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    with save_patch:
        connect_request = LiveCameraConnectRequest(
            camera_id="cam-config",
            source_type="ip_camera",
            transport="rtsp",
            url="rtsp://camera.local/live",
        )
        await connect_live_camera(connect_request)

        config_request = LiveCameraConfigRequest(
            camera_id="cam-config",
            roi_zones=[gate_zone],
            alert_rules=[person_rule],
            fps_budget=8.0,
        )
        response = await configure_live_camera(config_request)

    assert response.camera.roi_zones[0]["label"] == "gate"
    assert response.camera.alert_rules == ["person_zone:person:1:0.7"]
    assert response.camera.fps_budget == 8.0
|
|