"""Tests for text generation."""
|
|
| from __future__ import annotations |
|
|
| import sys |
| import time |
| from collections.abc import Mapping |
| from contextlib import contextmanager |
| from queue import Queue |
| from types import SimpleNamespace |
| from typing import Any |
| from unittest.mock import AsyncMock, patch |
|
|
| import httpx |
| import pytest |
| from fastapi import FastAPI |
| from fastapi.testclient import TestClient |
| from pydantic import ValidationError |
| from transformers import GenerationConfig, PretrainedConfig |
|
|
| import maris_core.text.generate as text_generate_module |
| from maris_core.memory_context import ConversationMemoryStore |
| from maris_core.orchestrator.routing import resolve_text_model |
| from maris_core.text.generate import ( |
| DEFAULT_MAX_NEW_TOKENS, |
| FALLBACK_MODEL_NAME, |
| GenerateRequest, |
| _sanitize_response_text, |
| call_generation_pipeline, |
| generate, |
| get_text_model_readiness, |
| ) |
| from maris_core.text.generate import ( |
| router as text_router, |
| ) |
| from maris_core.text.tools import execute_tool_trace, plan_tool_use |
|
|
|
|
def _build_text_app() -> FastAPI:
    """Return a fresh FastAPI app with the text router mounted under ``/v1/text``."""
    text_app = FastAPI()
    text_app.include_router(text_router, prefix="/v1/text")
    return text_app
|
|
|
|
@contextmanager
def _reset_pipeline_runtime() -> Any:
    """Temporarily blank out the pipeline state kept on ``text_generate_module``.

    On entry the cached pipeline, loading flag, failure bookkeeping, runtime
    model name and compatibility-restore flag are set to fixed blank values.
    On exit those attributes, together with ``PIPELINE_RETRY_COOLDOWN_SECONDS``
    (which a test may overwrite inside the block), are put back to the values
    captured on entry.
    """
    blank_values: dict[str, Any] = {
        "_pipeline": None,
        "_pipeline_loading": False,
        "_pipeline_last_failure_at": 0.0,
        "_pipeline_last_error": None,
        "_pipeline_runtime_model": "",
        "_pipeline_compatibility_restore_active": False,
    }
    # The cooldown constant is only snapshotted and restored, never reset here.
    tracked_names = (*blank_values, "PIPELINE_RETRY_COOLDOWN_SECONDS")
    snapshot = {name: getattr(text_generate_module, name) for name in tracked_names}
    for name, blank in blank_values.items():
        setattr(text_generate_module, name, blank)
    try:
        yield
    finally:
        for name, saved in snapshot.items():
            setattr(text_generate_module, name, saved)
|
|
|
|
def test_get_pipeline_starts_background_load_and_returns_none_while_warming_up() -> None:
    """``get_pipeline`` returns ``None`` while the build is blocked, then the built value."""
    build_entered: Queue[bool] = Queue()
    build_may_finish: Queue[bool] = Queue()

    def fake_build_pipeline() -> str:
        # Announce that the build started, then block until the test releases it.
        build_entered.put(True)
        build_may_finish.get(timeout=1)
        return "loaded-pipeline"

    build_patch = patch(
        "maris_core.text.generate._build_pipeline", side_effect=fake_build_pipeline
    )
    with _reset_pipeline_runtime(), build_patch:
        first_result = text_generate_module.get_pipeline()
        assert first_result is None
        assert build_entered.get(timeout=1) is True
        # The fake build is still blocked, so no pipeline may be handed out yet.
        assert text_generate_module.get_pipeline() is None

        build_may_finish.put(True)
        give_up_at = time.monotonic() + 1
        loaded = None
        while loaded is None and time.monotonic() < give_up_at:
            loaded = text_generate_module.get_pipeline()
            if loaded is None:
                time.sleep(0.01)

        assert loaded == "loaded-pipeline"
|
|
|
|
def test_get_text_model_readiness_transitions_from_cold_to_warming_up_to_ready() -> None:
    """Follow the readiness report through ``cold``, ``warming_up`` and ``ready``."""
    build_entered: Queue[bool] = Queue()
    build_may_finish: Queue[bool] = Queue()

    def fake_build_pipeline() -> str:
        # Block inside the build so the intermediate state can be observed.
        build_entered.put(True)
        build_may_finish.get(timeout=1)
        return "loaded-pipeline"

    build_patch = patch(
        "maris_core.text.generate._build_pipeline", side_effect=fake_build_pipeline
    )
    with _reset_pipeline_runtime(), build_patch:
        initial = get_text_model_readiness()
        assert initial["ready"] is False
        assert initial["state"] == "cold"
        assert initial["compatibility_restore_active"] is False
        assert initial["model"]

        after_trigger = get_text_model_readiness(start_loading=True)
        assert after_trigger["ready"] is False
        assert after_trigger["state"] == "warming_up"
        assert build_entered.get(timeout=1) is True
        assert get_text_model_readiness()["state"] == "warming_up"

        build_may_finish.put(True)
        give_up_at = time.monotonic() + 1
        latest: dict[str, Any] | None = None
        while time.monotonic() < give_up_at:
            latest = get_text_model_readiness()
            if not latest["ready"]:
                time.sleep(0.01)
                continue
            break

        assert latest is not None
        assert latest["ready"] is True
        assert latest["state"] == "ready"
        assert latest["compatibility_restore_active"] is False
        assert latest["model"]
|
|
|
|
def test_get_pipeline_throttles_retries_after_failed_background_load() -> None:
    """A build that returns ``None`` is only attempted again once the cooldown has elapsed."""
    build_calls: list[None] = []

    def fake_build_pipeline() -> Any:
        # Each invocation is counted; the build never yields a pipeline.
        build_calls.append(None)
        return None

    def wait_for_load_to_settle() -> None:
        # Poll (for at most one second) until the module's loading flag clears.
        give_up_at = time.monotonic() + 1
        while text_generate_module._pipeline_loading and time.monotonic() < give_up_at:
            time.sleep(0.01)

    with _reset_pipeline_runtime():
        text_generate_module.PIPELINE_RETRY_COOLDOWN_SECONDS = 60.0
        with patch("maris_core.text.generate._build_pipeline", side_effect=fake_build_pipeline):
            assert text_generate_module.get_pipeline() is None
            wait_for_load_to_settle()

            assert len(build_calls) == 1
            # Asking again right away must not trigger another build.
            assert text_generate_module.get_pipeline() is None
            assert len(build_calls) == 1

            # Backdate the recorded failure to just beyond the cooldown window.
            cooldown = text_generate_module.PIPELINE_RETRY_COOLDOWN_SECONDS
            text_generate_module._pipeline_last_failure_at = time.monotonic() - cooldown - 1.0
            assert text_generate_module.get_pipeline() is None
            wait_for_load_to_settle()

            assert len(build_calls) == 2
|
|
|
|
def test_get_text_model_readiness_reports_retry_cooldown_after_failed_load() -> None:
    """A failure timestamp of "now" with a 60s cooldown is reported as ``retry_cooldown``."""
    with _reset_pipeline_runtime():
        text_generate_module._pipeline_last_failure_at = time.monotonic()
        text_generate_module.PIPELINE_RETRY_COOLDOWN_SECONDS = 60.0

        report = get_text_model_readiness()

    assert report["ready"] is False
    assert report["state"] == "retry_cooldown"
    assert report["retry_after_seconds"] >= 1
|
|
|
|
def test_build_pipeline_wraps_runtime_model_in_compatibility_restore() -> None:
    """``_build_pipeline`` loads from the compatibility path but reports the requested model."""
    seen: dict[str, Any] = {}

    def fake_pipeline(task: str, *, model: str, device_map: str, trust_remote_code: bool) -> str:
        # Record every argument handed to the stubbed ``transformers.pipeline``.
        seen["task"] = task
        seen["model"] = model
        seen["device_map"] = device_map
        seen["trust_remote_code"] = trust_remote_code
        return "runtime-pipeline"

    @contextmanager
    def fake_compat_path(model_name: str):
        # Remember which model was requested and hand back a fixed local path.
        seen["requested_model"] = model_name
        yield "/tmp/maris-compat-restored"

    stub_transformers = SimpleNamespace(pipeline=fake_pipeline)
    resolver_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="custom-user/maris-runtime"
    )
    modules_patch = patch.dict(sys.modules, {"transformers": stub_transformers})
    compat_patch = patch("maris_core.text.generate.maris_hf_compatible_path", fake_compat_path)

    with _reset_pipeline_runtime(), resolver_patch, modules_patch, compat_patch:
        built = text_generate_module._build_pipeline()
        # Read readiness before the runtime state is restored on exit.
        report = text_generate_module.get_text_model_readiness()

    assert built == "runtime-pipeline"
    assert seen["requested_model"] == "custom-user/maris-runtime"
    assert seen["model"] == "/tmp/maris-compat-restored"
    assert report["model"] == "custom-user/maris-runtime"
    assert report["compatibility_restore_active"] is True
|
|
|
|
def test_resolve_text_model_prefers_runtime_override(monkeypatch) -> None:
    """``MARIS_RUNTIME_TEXT_MODEL`` wins over ``TEXT_MODEL`` when both are set."""
    runtime_override = "Qwen/Qwen2.5-7B-Instruct"
    monkeypatch.setenv("TEXT_MODEL", "MarisUK/maris-ai-text")
    monkeypatch.setenv("MARIS_RUNTIME_TEXT_MODEL", runtime_override)

    resolved = resolve_text_model()

    assert resolved == runtime_override
|
|
|
|
def test_resolve_text_model_accepts_generic_huggingface_repo(monkeypatch) -> None:
    """With no runtime override set, the ``TEXT_MODEL`` repo id is returned unchanged."""
    generic_repo = "custom-user/private-text-model"
    monkeypatch.delenv("MARIS_RUNTIME_TEXT_MODEL", raising=False)
    monkeypatch.setenv("TEXT_MODEL", generic_repo)

    resolved = resolve_text_model()

    assert resolved == generic_repo
|
|
|
|
def test_resolve_text_model_rejects_invalid_runtime_override(monkeypatch) -> None:
    """A runtime override of ``not-a-valid-model`` makes resolution raise ``RuntimeError``."""
    bad_override = "not-a-valid-model"
    monkeypatch.setenv("MARIS_RUNTIME_TEXT_MODEL", bad_override)

    with pytest.raises(RuntimeError):
        resolve_text_model()
|
|
|
|
def test_call_generation_pipeline_clears_max_length_from_generation_config() -> None:
    """Per-call limits arrive via ``generation_config`` and its ``max_length`` is cleared."""
    forwarded: dict[str, Any] = {}

    class FakePipeline:
        # Pre-existing config whose ``max_length`` and ``temperature`` should be overridden.
        generation_config = GenerationConfig(max_length=20, temperature=0.8)

        def __call__(self, messages: list[dict[str, str]], **kwargs: Any) -> list[dict[str, Any]]:
            nonlocal forwarded
            forwarded = kwargs
            assistant_turn = {"role": "assistant", "content": "Sveiki"}
            return [{"generated_text": [assistant_turn]}]

    chat = [{"role": "user", "content": "Sveiki"}]
    call_generation_pipeline(FakePipeline(), chat, max_new_tokens=160, temperature=0.1)

    effective_config = forwarded["generation_config"]
    # The raw limits must not be passed as loose keyword arguments.
    assert "max_new_tokens" not in forwarded
    assert "temperature" not in forwarded
    assert effective_config.max_new_tokens == 160
    assert effective_config.max_length is None
    assert effective_config.temperature == 0.1
|
|
|
|
def test_call_generation_pipeline_builds_generation_config_from_model_config() -> None:
    """A generation config is still passed when the pipeline only exposes ``model.config``."""
    forwarded: dict[str, Any] = {}

    class FakePipeline:
        # No ``generation_config`` attribute here, only a bare model config.
        model = SimpleNamespace(config=PretrainedConfig())

        def __call__(self, messages: list[dict[str, str]], **kwargs: Any) -> list[dict[str, Any]]:
            nonlocal forwarded
            forwarded = kwargs
            assistant_turn = {"role": "assistant", "content": "Sveiki"}
            return [{"generated_text": [assistant_turn]}]

    chat = [{"role": "user", "content": "Sveiki"}]
    call_generation_pipeline(FakePipeline(), chat, max_new_tokens=64, temperature=0.0)

    effective_config = forwarded["generation_config"]
    assert effective_config.max_new_tokens == 64
    assert effective_config.max_length is None
    assert effective_config.do_sample is False
|
|
|
|
@pytest.mark.asyncio
async def test_generate_endpoint_raises_when_model_is_unavailable() -> None:
    """Check the graceful fallback when the model is unavailable."""
    no_pipeline = patch("maris_core.text.generate.get_pipeline", return_value=None)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with no_pipeline, save_patch as save_conversation:
        request = GenerateRequest(message="sveiki")
        response = await generate(request)

    assert response.model == FALLBACK_MODEL_NAME
    assert "Pilnais modelis šobrīd nav pieejams" in response.response
    assert response.tokens_used > 0
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["fallback_used"] is True
|
|
|
|
@pytest.mark.asyncio
async def test_generate_uses_requested_hf_fallback_model_when_runtime_is_unavailable() -> None:
    """With no local pipeline, the request's ``fallback_model`` is sent to the HF client."""
    requested_model = "Qwen/Qwen2.5-72B-Instruct"
    hf_reply = "Šī ir īsta fallback atbilde no HF modeļa."

    class FakeClient:
        def __init__(self) -> None:
            self.called_model: str | None = None

        def chat_completion(
            self, *, model: str, messages: list[dict[str, str]], max_tokens: int, temperature: float
        ) -> dict[str, Any]:
            # Only the model name matters to this test.
            del messages, max_tokens, temperature
            self.called_model = model
            return {"choices": [{"message": {"content": hf_reply}}]}

    fake_client = FakeClient()
    stub_modules = {
        "huggingface_hub": SimpleNamespace(InferenceClient=FakeClient),
        "huggingface_hub.utils": SimpleNamespace(HfHubHTTPError=RuntimeError),
    }
    no_pipeline = patch("maris_core.text.generate.get_pipeline", return_value=None)
    client_patch = patch(
        "maris_core.text.generate.create_hf_inference_client", return_value=fake_client
    )
    modules_patch = patch.dict(sys.modules, stub_modules)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with no_pipeline, client_patch, modules_patch, save_patch as save_conversation:
        response = await generate(
            GenerateRequest(message="Sveiki", fallback_model=requested_model)
        )

    assert response.model == requested_model
    assert response.response == hf_reply
    assert fake_client.called_model == requested_model
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["fallback_used"] is True
    assert saved_metadata["requested_fallback_model"] == requested_model
|
|
|
|
@pytest.mark.asyncio
async def test_generate_returns_emotional_metadata() -> None:
    """A frustrated-sounding message yields emotion metadata plus request/session ids.

    The fake pipeline echoes the prompt with one assistant turn appended and
    reports a fixed ``usage.total_tokens`` value, which the response is
    expected to surface as ``tokens_used``.
    """

    # A named function (rather than a lambda bound to a name) keeps this fake
    # in line with the other fake pipelines in this module.
    def fake_pipeline(messages, max_new_tokens, temperature):
        del max_new_tokens, temperature
        return [
            {
                "generated_text": messages
                + [{"role": "assistant", "content": "Sapratu, iesim cauri mierīgi pa soļiem."}],
                "usage": {"total_tokens": 321},
            }
        ]

    with (
        patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline),
        patch("maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_conversation",
            new_callable=AsyncMock,
        ) as save_conversation,
    ):
        response = await generate(
            GenerateRequest(message="Šis nestrādā un mani tas kaitina", profile="general")
        )

    assert response.response == "Sapratu, iesim cauri mierīgi pa soļiem."
    assert response.detected_emotion == "frustrated"
    assert response.response_style == "calm_reassuring_step_by_step"
    assert response.emotion_confidence >= 0.6
    assert response.tokens_used == 321
    assert response.request_id
    assert response.session_id.startswith("ephemeral-")
    assert response.prompt_messages >= 2
    save_conversation.assert_awaited_once()
    # The persisted metadata must carry the same ids as the response.
    metadata = save_conversation.await_args.kwargs["metadata"]
    assert metadata["request_id"] == response.request_id
    assert metadata["session_id"] == response.session_id
|
|
|
|
@pytest.mark.asyncio
async def test_generate_injects_relevant_memory_context() -> None:
    """A remembered message for the session shows up in the prompt as a system memory note."""
    seen_messages: list[dict[str, str]] = []
    reply_text = "Atceros iepriekšējo kontekstu."

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        return [{"generated_text": messages + [{"role": "assistant", "content": reply_text}]}]

    session_memory = ConversationMemoryStore()
    session_memory.remember_message(
        "session-42", "assistant", "Iepriekš runājām par API retry stratēģiju."
    )

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    memory_patch = patch("maris_core.text.generate.memory_store", session_memory)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, memory_patch, save_patch:
        response = await generate(
            GenerateRequest(message="Turpinām par retry API klientu", session_id="session-42")
        )

    assert response.response == reply_text
    assert response.memory_matches >= 1
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Saistītā atmiņa" in text for text in system_texts)
|
|
|
|
@pytest.mark.asyncio
async def test_generate_injects_user_focus_context() -> None:
    """Two remembered user messages produce a long-term-focus system note in the prompt."""
    seen_messages: list[dict[str, str]] = []
    reply_text = "Balstos tavā ilgtermiņa fokusā un mērķos."

    session_memory = ConversationMemoryStore()
    earlier_user_turns = (
        "Es gribu, lai mans AI asistents mācās no iepriekšējām sarunām.",
        "Man svarīgi, lai atbildes paliek pamatotas ar faktiem.",
    )
    for turn in earlier_user_turns:
        session_memory.remember_message("session-focus", "user", turn)

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        return [{"generated_text": messages + [{"role": "assistant", "content": reply_text}]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    memory_patch = patch("maris_core.text.generate.memory_store", session_memory)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, memory_patch, save_patch as save_conversation:
        response = await generate(
            GenerateRequest(
                message="Palīdzi man pietuvoties īstākam AI",
                session_id="session-focus",
            )
        )

    assert response.response == reply_text
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Lietotāja ilgtermiņa fokuss" in text for text in system_texts)
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["user_focus_items"] == 2
|
|
|
|
@pytest.mark.asyncio
async def test_generate_injects_active_thread_context() -> None:
    """Two remembered user messages produce an active-threads system note in the prompt."""
    seen_messages: list[dict[str, str]] = []
    reply_text = "Turpinu aktīvos pavedienus no iepriekšējās sarunas."

    session_memory = ConversationMemoryStore()
    earlier_user_turns = (
        "Kā man uzbūvēt uzticamu AI asistentu ar ilgtermiņa atmiņu?",
        "Turpinām ar nākamajiem 3 soļiem un prioritātēm.",
    )
    for turn in earlier_user_turns:
        session_memory.remember_message("session-thread", "user", turn)

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        return [{"generated_text": messages + [{"role": "assistant", "content": reply_text}]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    memory_patch = patch("maris_core.text.generate.memory_store", session_memory)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, memory_patch, save_patch as save_conversation:
        response = await generate(
            GenerateRequest(
                message="Palīdzi man turpināt šo virzienu",
                session_id="session-thread",
            )
        )

    assert response.response == reply_text
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Aktīvie pavedieni šai sesijai" in text for text in system_texts)
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["active_thread_items"] == 2
|
|
|
|
@pytest.mark.asyncio
async def test_generate_injects_vision_context_and_stores_it_in_memory() -> None:
    """A supplied ``vision_context`` reaches the prompt and is retrievable from memory later."""
    seen_messages: list[dict[str, str]] = []
    reply_text = "Attēlā redzams monitora dashboard ar kļūdu paneli."
    session_memory = ConversationMemoryStore()

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        return [{"generated_text": messages + [{"role": "assistant", "content": reply_text}]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    memory_patch = patch("maris_core.text.generate.memory_store", session_memory)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, memory_patch, save_patch:
        response = await generate(
            GenerateRequest(
                message="Pastāsti ko redzi šajā screenshot",
                session_id="vision-session",
                vision_context={
                    "summary": "Screenshot rāda monitora dashboard ar sarkanu incident alert.",
                    "source": "upload",
                    "model": "facebook/detr-resnet-50",
                    "detections": 3,
                    "width": 1024,
                    "height": 768,
                },
            )
        )

    assert response.response == reply_text
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Vizuālais konteksts" in text for text in system_texts)
    # The summary text should now be findable through the session's memory.
    recalled = session_memory.retrieve_relevant_context("vision-session", "incident alert")
    assert recalled
|
|
|
|
@pytest.mark.asyncio
async def test_generate_uses_workspace_tools_for_repo_grounding(tmp_path) -> None:
    """A question about repo docs yields a tool trace and a grounding system message."""
    seen_messages: list[dict[str, str]] = []
    accepted_modes = {"tool_augmented", "multi_step"}

    readme_path = tmp_path / "docs" / "README.md"
    readme_path.parent.mkdir()
    readme_path.write_text(
        "# Maris\nCanonical health endpoint is /health and /ready is compatibility only.\n",
        encoding="utf-8",
    )

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        assistant_turn = {
            "role": "assistant",
            "content": "Repo dokumentācija rāda, ka kanoniskais health endpoints ir /health.",
        }
        return [{"generated_text": messages + [assistant_turn]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    # Point the workspace tools at the temporary directory created above.
    workspace_patch = patch("maris_core.text.tools.WORKSPACE_ROOT", tmp_path)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, workspace_patch, save_patch as save_conversation:
        response = await generate(
            GenerateRequest(message="Kas README rakstīts par health endpoint repo dokumentācijā?")
        )

    assert response.tool_trace is not None
    assert response.tool_trace.mode in accepted_modes
    assert response.tool_trace.steps
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Tool grounding context:" in text for text in system_texts)
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["tool_steps"] >= 1
    assert saved_metadata["tool_mode"] in accepted_modes
|
|
|
|
@pytest.mark.asyncio
async def test_execute_tool_trace_follows_web_search_with_fetch() -> None:
    """The executed trace starts with ``web_search`` then ``web_fetch`` and keeps page text."""
    question = "Kas ir jaunākais Maris release?"
    page_html = (
        "<html><head><title>Maris Release</title></head>"
        "<body><main>Latest Maris release adds grounded tool orchestration.</main></body></html>"
    )

    def handler(request: httpx.Request) -> httpx.Response:
        # Serve canned responses for the two hosts involved; anything else is a test failure.
        host = request.url.host
        if host == "api.duckduckgo.com":
            search_payload = {
                "Heading": "Maris release notes",
                "AbstractText": "",
                "RelatedTopics": [
                    {
                        "Text": "Maris release notes - Latest changes",
                        "FirstURL": "https://example.com/maris-release",
                    }
                ],
            }
            return httpx.Response(200, json=search_payload)
        if host == "example.com":
            return httpx.Response(
                200,
                text=page_html,
                headers={"content-type": "text/html; charset=utf-8"},
            )
        raise AssertionError(f"Unexpected URL: {request.url}")

    plan = plan_tool_use(question)
    assert plan is not None

    mock_transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(transport=mock_transport) as client:
        trace = await execute_tool_trace(plan, message=question, client=client, max_steps=4)

    leading_step_names = [step.name for step in trace.steps[:2]]
    assert leading_step_names == ["web_search", "web_fetch"]
    source_kinds = [source.kind for source in trace.grounding_sources]
    assert "web_fetch" in source_kinds
    snippets = [source.snippet or "" for source in trace.grounding_sources]
    assert any("grounded tool orchestration" in snippet for snippet in snippets)
|
|
|
|
@pytest.mark.asyncio
async def test_generate_reads_exact_workspace_path_and_adds_grounding_message(tmp_path) -> None:
    """Naming ``docs/guide.md`` triggers a ``workspace_read`` step and a grounding message."""
    seen_messages: list[dict[str, str]] = []

    guide_path = tmp_path / "docs" / "guide.md"
    guide_path.parent.mkdir()
    guide_path.write_text(
        "# Deploy\nUse /ready for platform readiness checks.\n",
        encoding="utf-8",
    )

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        assistant_turn = {
            "role": "assistant",
            "content": "docs/guide.md rāda, ka readiness checks jābalsta uz /ready.",
        }
        return [{"generated_text": messages + [assistant_turn]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    # Point the workspace tools at the temporary directory created above.
    workspace_patch = patch("maris_core.text.tools.WORKSPACE_ROOT", tmp_path)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, workspace_patch, save_patch:
        response = await generate(
            GenerateRequest(message="Ko docs/guide.md saka par readiness checks?", max_tool_steps=6)
        )

    assert response.tool_trace is not None
    step_names = [step.name for step in response.tool_trace.steps]
    assert "workspace_read" in step_names
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any(
        "docs/guide.md" in text and "Tool grounding context:" in text for text in system_texts
    )
|
|
|
|
@pytest.mark.asyncio
async def test_generate_uses_workspace_tools_for_repo_debug_prompt(tmp_path) -> None:
    """A debug prompt naming two repo files grounds the answer in both of them."""
    seen_messages: list[dict[str, str]] = []

    # Lay out a backend and a frontend file in the temporary workspace.
    workspace_files = (
        (
            tmp_path / "backend-rust" / "src" / "api" / "chat.rs",
            'event: complete\nlet route = "/api/chat/stream";\n',
        ),
        (
            tmp_path / "frontend" / "app" / "chat" / "page.tsx",
            "if (event.type === 'complete') finalizeStream();\n",
        ),
    )
    for file_path, file_text in workspace_files:
        file_path.parent.mkdir(parents=True)
        file_path.write_text(file_text, encoding="utf-8")

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        del max_new_tokens, temperature
        seen_messages = messages
        assistant_turn = {
            "role": "assistant",
            "content": "Abi faili rāda, ka complete event ir jāsaskaņo starp backend un frontend parseri.",
        }
        return [{"generated_text": messages + [assistant_turn]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    workspace_patch = patch("maris_core.text.tools.WORKSPACE_ROOT", tmp_path)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, workspace_patch, save_patch:
        response = await generate(
            GenerateRequest(
                message=(
                    "Debug SSE mismatch starp backend-rust/src/api/chat.rs un "
                    "frontend/app/chat/page.tsx, balstoties uz esošo repo kodu."
                )
            )
        )

    assert response.tool_trace is not None
    assert len(response.tool_trace.grounding_sources) >= 2
    step_names = [step.name for step in response.tool_trace.steps]
    assert "workspace_search" in step_names
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any(
        "backend-rust/src/api/chat.rs" in text and "frontend/app/chat/page.tsx" in text
        for text in system_texts
    )
|
|
|
|
@pytest.mark.asyncio
async def test_generate_applies_selected_persona_to_prompt_and_response() -> None:
    """Selecting the ``strategist`` persona is reflected in the prompt and the response."""
    seen_messages: list[dict[str, str]] = []

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        assistant_turn = {
            "role": "assistant",
            "content": "Skatos uz to kā sistēmu un prioritāšu jautājumu.",
        }
        return [{"generated_text": messages + [assistant_turn]}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, save_patch:
        response = await generate(
            GenerateRequest(
                message="Palīdzi ar produkta roadmap",
                profile="general",
                persona_id="strategist",
            )
        )

    assert response.persona_id == "strategist"
    assert response.persona_title == "Systems Strategist"
    # The first two prompt messages carry the persona line and the runtime contract.
    first_message, second_message = seen_messages[0], seen_messages[1]
    assert "Aktīvā persona: Systems Strategist." in first_message["content"]
    assert "Assistant runtime contract:" in second_message["content"]
|
|
|
|
@pytest.mark.asyncio
async def test_generate_adds_coding_contract_for_coder_requests() -> None:
    """With ``profile="coder"`` the second prompt message contains the coding contract."""
    seen_messages: list[dict[str, str]] = []

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        del max_new_tokens, temperature
        seen_messages = messages
        return [{"generated_text": "```python\nprint('ok')\n```"}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, save_patch:
        await generate(
            GenerateRequest(
                message="Uzraksti Python helperi ar validāciju un testiem",
                profile="coder",
            )
        )

    contract_text = seen_messages[1]["content"]
    assert "Coding response contract:" in contract_text
    assert "edge cases" in contract_text
|
|
|
|
@pytest.mark.asyncio
async def test_generate_adds_session_summary_for_longer_persona_continuity() -> None:
    """Earlier turns in the session lead to a session-summary system message in the prompt."""
    seen_messages: list[dict[str, str]] = []
    reply_text = "Turpinām ar strukturētu roadmap."

    session_memory = ConversationMemoryStore()
    earlier_turns = (
        ("user", "Mēs būvējam incident response roadmap komandas līmenī."),
        ("assistant", "Tu gribēji prioritizēt alerting, ownership un postmortem procesu."),
    )
    for role, text in earlier_turns:
        session_memory.remember_message("session-77", role, text)

    def fake_pipeline(messages, max_new_tokens, temperature):
        nonlocal seen_messages
        seen_messages = messages
        return [{"generated_text": reply_text}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline)
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    memory_patch = patch("maris_core.text.generate.memory_store", session_memory)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, memory_patch, save_patch:
        response = await generate(
            GenerateRequest(
                message="Kas ir nākamās 3 prioritātes?",
                session_id="session-77",
                persona_id="strategist",
            )
        )

    assert response.response == reply_text
    system_texts = [entry["content"] for entry in seen_messages if entry["role"] == "system"]
    assert any("Sesijas kopsavilkums ilgākai konsekvencei" in text for text in system_texts)
|
|
|
|
@pytest.mark.asyncio
async def test_generate_handles_string_output_with_token_estimation() -> None:
    """A plain-string ``generated_text`` is returned as-is with a positive token count.

    The fake pipeline reports no ``usage`` block, so ``tokens_used`` has to be
    derived some other way; the test only requires it to be greater than zero.
    """

    # A named function (rather than a lambda bound to a name) keeps this fake
    # in line with the other fake pipelines in this module.
    def fake_pipeline(messages, max_new_tokens, temperature):
        del messages, max_new_tokens, temperature
        return [{"generated_text": "Profesionāla atbilde bez čata masīva."}]

    with (
        patch("maris_core.text.generate.get_pipeline", return_value=fake_pipeline),
        patch("maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_conversation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate(GenerateRequest(message="Dod īsu atbildi"))

    assert response.response == "Profesionāla atbilde bez čata masīva."
    assert response.tokens_used > 0
|
|
|
|
def test_sanitize_response_text_removes_prompt_echo_and_assistant_prefix() -> None:
    """Echoed ``System:``/``User:`` lines and the ``Assistant:`` prefix are stripped."""
    prompt_messages = [
        {"role": "system", "content": "Tu esi Maris AI."},
        {"role": "user", "content": "Dod īsu atbildi"},
    ]
    raw_output = "System: Tu esi Maris AI.\nUser: Dod īsu atbildi\nAssistant: Precīza atbilde."

    assert _sanitize_response_text(raw_output, prompt_messages) == "Precīza atbilde."
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("error_type", [TypeError, ValueError, AttributeError])
async def test_generate_falls_back_to_prompt_text_for_non_chat_pipelines(
    error_type: type[Exception],
) -> None:
    """When the message-list call raises, ``generate`` retries with a flat prompt string."""
    recorded_calls: list[tuple[object, dict[str, object]]] = []
    reply_text = "Īsa atbilde bez chat template kļūdas."

    class FakePipeline:
        def __call__(self, payload: object, **kwargs: Any) -> list[Mapping[str, str]]:
            recorded_calls.append((payload, dict(kwargs)))
            # Reject the chat-style payload so the caller has to fall back.
            if isinstance(payload, list):
                raise error_type("chat messages are not supported")
            assert isinstance(payload, str)
            assert "User: Izveido īsu atbildi" in payload
            assert payload.endswith("Assistant:")
            assert kwargs["return_full_text"] is False
            return [{"generated_text": reply_text}]

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=FakePipeline())
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model", return_value="MarisUK/test-model"
    )
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, save_patch:
        response = await generate(GenerateRequest(message="Izveido īsu atbildi"))

    assert response.response == reply_text
    assert len(recorded_calls) == 2
    first_payload, second_payload = recorded_calls[0][0], recorded_calls[1][0]
    assert isinstance(first_payload, list)
    assert isinstance(second_payload, str)
|
|
|
|
@pytest.mark.asyncio
async def test_generate_falls_back_to_runtime_response_when_output_payload_is_invalid() -> None:
    """An empty output item yields the fallback model, fallback text and fallback metadata."""

    def empty_payload_pipeline(*args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        # A single output item with no fields at all.
        return [{}]

    pipeline_patch = patch(
        "maris_core.text.generate.get_pipeline",
        return_value=empty_payload_pipeline,
    )
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, save_patch as save_conversation:
        request = GenerateRequest(message="Dod man virzienu nākamajam solim")
        response = await generate(request)

    assert response.model == FALLBACK_MODEL_NAME
    assert "drošu rezerves atbildi" in response.response
    # The metadata handed to the persistence layer must flag the fallback.
    saved_metadata = save_conversation.await_args.kwargs["metadata"]
    assert saved_metadata["fallback_used"] is True
|
|
|
|
def test_generate_request_rejects_invalid_message_and_history() -> None:
    """A whitespace-only message and a history entry with role ``tool`` both fail validation."""
    invalid_payloads: tuple[dict[str, Any], ...] = (
        {"message": " "},
        {"message": "Derīga ziņa", "history": [{"role": "tool", "content": "x"}]},
    )

    for payload in invalid_payloads:
        with pytest.raises(ValidationError):
            GenerateRequest(**payload)
|
|
|
|
def test_generate_request_allows_large_max_new_tokens() -> None:
    """An explicit ``max_new_tokens`` of 20 000 is accepted and stored unchanged."""
    requested_tokens = 20_000

    request = GenerateRequest(
        message="Uzraksti garu profesionālu atbildi",
        max_new_tokens=requested_tokens,
    )

    assert request.max_new_tokens == requested_tokens
|
|
|
|
def test_generate_request_uses_large_default_max_new_tokens() -> None:
    """Omitting ``max_new_tokens`` leaves the request at ``DEFAULT_MAX_NEW_TOKENS``."""
    request = GenerateRequest(message="Dod pilnu risinājumu")
    effective_limit = request.max_new_tokens

    assert effective_limit == DEFAULT_MAX_NEW_TOKENS
|
|
|
|
def test_generate_request_accepts_configurable_max_tool_steps() -> None:
    """A caller-supplied ``max_tool_steps`` value is kept on the request."""
    requested_steps = 18

    request = GenerateRequest(message="Izpildi ar rīkiem", max_tool_steps=requested_steps)

    assert request.max_tool_steps == requested_steps
|
|
|
|
def test_plan_tool_use_detects_external_verification_requests() -> None:
    """A request to check official sources produces a tool-using plan."""
    prompt = "Pārbaudi oficiālos avotos, vai Anthropic Claude 4 joprojām ir aktuāls."
    accepted_modes = {"tool_augmented", "multi_step"}

    trace = plan_tool_use(prompt)

    assert trace is not None
    assert trace.mode in accepted_modes
|
|
|
|
def test_generate_stream_endpoint_uses_fallback_stream_when_model_is_unavailable() -> None:
    """With no pipeline available the stream endpoint still emits delta and complete events."""
    # get_pipeline returning None stands in for an unavailable model.
    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=None)
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, save_patch:
        client = TestClient(_build_text_app())
        stream = client.stream("POST", "/v1/text/generate/stream", json={"message": "Sveiki"})
        with stream as response:
            # Collect the whole event stream into one string for substring checks.
            body = "".join(response.iter_text())

    assert response.status_code == 200
    for expected_fragment in ("event: delta", "event: complete", FALLBACK_MODEL_NAME):
        assert expected_fragment in body
|
|
|
|
def test_generate_stream_endpoint_streams_real_model_deltas() -> None:
    """The stream endpoint relays chunks pushed by the model into ``delta`` events.

    A fake pipeline (tokenizer + model) and a fake ``transformers`` module are
    installed so that ``model.generate`` pushes two text chunks through the
    streamer it receives. The test then checks the emitted events and that the
    EOS/PAD token ids are passed via ``generation_config`` rather than as
    top-level ``generate`` keyword arguments.
    """
    captured_generation_kwargs: dict[str, Any] = {}
    streamed_chunks = ("Sveiki ", "no straumes!")

    class FakeTensor:
        def to(self, device: str) -> FakeTensor:
            # Device moves are a no-op for the stand-in tensor.
            return self

    class FakeTokenizer:
        eos_token_id = 7
        pad_token_id = 7

        def __call__(self, prompt: str, return_tensors: str) -> dict[str, FakeTensor]:
            assert "Assistant:" in prompt
            assert return_tensors == "pt"
            return {"input_ids": FakeTensor()}

    class FakeModel:
        device = "cpu"

        def generate(self, **kwargs: Any) -> None:
            nonlocal captured_generation_kwargs
            # Keep the exact keyword arguments for the assertions after the request.
            captured_generation_kwargs = kwargs
            streamer = kwargs["streamer"]
            for chunk in streamed_chunks:
                streamer.put(chunk)
            streamer.end()

    class FakePipeline:
        tokenizer = FakeTokenizer()
        model = FakeModel()

    class FakeStoppingCriteria:
        def __call__(self, input_ids: Any, scores: Any, **kwargs: Any) -> bool:
            return False

    class FakeStoppingCriteriaList(list):
        """Plain list subclass exposed as ``StoppingCriteriaList`` on the fake module."""

    class FakeTextIteratorStreamer:
        def __init__(self, tokenizer: Any, skip_prompt: bool, skip_special_tokens: bool) -> None:
            self.queue: Queue[str | None] = Queue()

        def put(self, value: str) -> None:
            self.queue.put(value)

        def end(self) -> None:
            # ``None`` is the end-of-stream sentinel consumed by ``__next__``.
            self.queue.put(None)

        def __iter__(self) -> FakeTextIteratorStreamer:
            return self

        def __next__(self) -> str:
            if (item := self.queue.get(timeout=1)) is None:
                raise StopIteration
            return item

    # Minimal replacement for the ``transformers`` module entry in ``sys.modules``.
    fake_transformers = SimpleNamespace(
        GenerationConfig=GenerationConfig,
        StoppingCriteria=FakeStoppingCriteria,
        StoppingCriteriaList=FakeStoppingCriteriaList,
        TextIteratorStreamer=FakeTextIteratorStreamer,
    )

    pipeline_patch = patch("maris_core.text.generate.get_pipeline", return_value=FakePipeline())
    model_patch = patch(
        "maris_core.text.generate.resolve_text_model",
        return_value="MarisUK/test-model",
    )
    modules_patch = patch.dict("sys.modules", {"transformers": fake_transformers})
    save_patch = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_conversation",
        new_callable=AsyncMock,
    )

    with pipeline_patch, model_patch, modules_patch, save_patch:
        client = TestClient(_build_text_app())
        stream = client.stream("POST", "/v1/text/generate/stream", json={"message": "Sveiki"})
        with stream as response:
            body = "".join(response.iter_text())

    assert response.status_code == 200
    expected_fragments = (
        '{"delta": "Sveiki "}',
        '{"delta": "no straumes!"}',
        "event: complete",
        "MarisUK/test-model",
    )
    for fragment in expected_fragments:
        assert fragment in body

    generation_config = captured_generation_kwargs["generation_config"]
    assert generation_config.eos_token_id == 7
    assert generation_config.pad_token_id == 7
    # The token ids must not also be passed as top-level generate() kwargs.
    for token_key in ("eos_token_id", "pad_token_id"):
        assert token_key not in captured_generation_kwargs
|
|