| """Tests koda ģenerēšanai.""" |
|
|
| import sys |
| from pathlib import Path |
| from typing import Any |
| from unittest.mock import AsyncMock, patch |
|
|
| import pytest |
| from fastapi import HTTPException |
|
|
| from maris_core.code.generate_code import ( |
| CodeRequest, |
| FixCodeRequest, |
| ProjectFile, |
| _detect_stack, |
| _extract_code_block, |
| _extract_project_files, |
| fix_code, |
| generate_code, |
| ) |
|
|
|
|
def test_extract_code_block_with_fences() -> None:
    """Fenced output is split into the code inside the fence and the surrounding prose."""
    result = _extract_code_block(
        "Šeit ir kods:\n```\nprint('hello')\n```\nPaskaidrojums.",
        "Python",
    )
    code_part, prose_part = result
    assert "print" in code_part
    assert "Paskaidrojums" in prose_part
|
|
|
|
def test_extract_code_block_no_fences() -> None:
    """Unfenced output is still treated as code.

    Only the code half of the returned pair is checked here; the explanation
    half is deliberately ignored.
    """
    text = "print('hello world')"
    code, _ = _extract_code_block(text, "Python")
    assert "print" in code
|
|
|
|
def test_extract_project_files_from_structured_json_payload() -> None:
    """A fenced JSON payload yields its declared files, entrypoint and explanation.

    Files parsed from a model reply have no on-disk location yet, so every
    ``absolute_path`` is expected to be ``None``.
    """
    model_reply = """```json
{
"explanation": "Pilns kalkulatora projekts.",
"entrypoint": "index.html",
"files": [
{"path": "index.html", "content": "<html></html>"},
{"path": "assets/app.js", "content": "console.log('ok');"}
]
}
```"""
    expected_files = [
        ProjectFile(path="index.html", content="<html></html>", absolute_path=None),
        ProjectFile(path="assets/app.js", content="console.log('ok');", absolute_path=None),
    ]

    parsed_files, parsed_entrypoint, parsed_explanation = _extract_project_files(
        model_reply, language="HTML/CSS/JavaScript"
    )

    assert parsed_files == expected_files
    assert parsed_entrypoint == "index.html"
    assert parsed_explanation == "Pilns kalkulatora projekts."
|
|
|
|
def test_detect_stack_prefers_nextjs_from_prompt() -> None:
    """A prompt that names Next.js wins over the requested ``Python`` language."""
    prompt = "Uztaisi Next.js dashboard ar app router un TypeScript"
    assert _detect_stack(prompt, "Python", None) == "nextjs"
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_requires_text_model() -> None:
    """With no text pipeline and no fallback model, ``generate_code`` raises HTTP 503."""
    request = CodeRequest(prompt="Hello world skripts", language="Python")
    no_pipeline = patch("maris_core.code.generate_code.get_pipeline", return_value=None)
    # Persistence is stubbed out so the test never reaches the real HF integration.
    stubbed_save = patch(
        "maris_core.utils.hf_integration.HFIntegration.save_generation",
        new_callable=AsyncMock,
    )
    with no_pipeline, stubbed_save, pytest.raises(HTTPException) as exc_info:
        await generate_code(request)
    assert exc_info.value.status_code == 503
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_uses_requested_hf_fallback_model_when_text_runtime_is_unavailable() -> (
    None
):
    """With no local pipeline, the caller's ``fallback_model`` is sent to the HF client.

    The fake client records every model it is asked for. The check is made in
    the test body rather than inside the fake, so a mismatch cannot be wrapped or
    swallowed by the production error handling around the remote call. The
    check also proves the client was actually invoked.
    """
    requested_models: list[str] = []

    class FakeClient:
        """Stand-in HF inference client that records the requested model."""

        def chat_completion(
            self,
            *,
            model: str,
            messages: list[dict[str, str]],
            max_tokens: int,
            temperature: float,
        ) -> dict[str, Any]:
            del messages, max_tokens, temperature
            requested_models.append(model)
            return {
                "choices": [
                    {
                        "message": {
                            "content": "```python\ndef normalize_email(value: str) -> str:\n return value.strip().lower()\n```"
                        }
                    }
                ]
            }

    # Placeholder modules injected into sys.modules in place of huggingface_hub;
    # they expose only InferenceClient and HfHubHTTPError.
    fake_hf_module = type("FakeHFModule", (), {"InferenceClient": object})()
    fake_hf_utils = type("FakeHFUtils", (), {"HfHubHTTPError": RuntimeError})()

    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=None),
        patch("maris_core.text.generate.create_hf_inference_client", return_value=FakeClient()),
        patch.dict(
            sys.modules, {"huggingface_hub": fake_hf_module, "huggingface_hub.utils": fake_hf_utils}
        ),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate_code(
            CodeRequest(
                prompt="Uzraksti Python helperi normalize_email",
                language="Python",
                fallback_model="Qwen/Qwen2.5-Coder-32B-Instruct",
            )
        )

    # Non-empty, and every call used the requested fallback model.
    assert set(requested_models) == {"Qwen/Qwen2.5-Coder-32B-Instruct"}
    assert "normalize_email" in response.code
|
|
|
|
@pytest.mark.asyncio
async def test_fix_code_delegates() -> None:
    """``fix_code`` surfaces the same HTTP 503 as ``generate_code`` when no pipeline exists."""
    broken_request = FixCodeRequest(
        code="prit('hello')", error_message="NameError", language="Python"
    )
    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=None),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
        pytest.raises(HTTPException) as raised,
    ):
        await fix_code(broken_request)
    assert raised.value.status_code == 503
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_passes_large_max_new_tokens_to_pipeline() -> None:
    """A caller-supplied ``max_new_tokens`` of 20k reaches the pipeline unchanged."""
    seen: dict[str, int] = {}

    def fake_pipeline(
        messages: list[dict[str, Any]], *, max_new_tokens: int, temperature: float
    ) -> list[dict[str, list[dict[str, str]]]]:
        del messages, temperature
        seen["max_new_tokens"] = max_new_tokens
        return [{"generated_text": [{"role": "assistant", "content": "print('hello world')"}]}]

    request = CodeRequest(
        prompt="Uzraksti pilnu Python servisu",
        language="Python",
        max_new_tokens=20_000,
    )
    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=fake_pipeline),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate_code(request)

    # .get() yields None when the pipeline was never called, so this still fails then.
    assert seen.get("max_new_tokens") == 20_000
    assert response.code == "print('hello world')"
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_uses_stronger_engineering_system_prompt() -> None:
    """The first message sent to the pipeline carries the engineering-quality instructions."""
    captured_messages: list[dict[str, Any]] = []

    def fake_pipeline(
        messages: list[dict[str, Any]], *, max_new_tokens: int, temperature: float
    ) -> list[dict[str, list[dict[str, str]]]]:
        nonlocal captured_messages
        del max_new_tokens, temperature
        captured_messages = messages
        return [{"generated_text": [{"role": "assistant", "content": "print('hello world')"}]}]

    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=fake_pipeline),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        await generate_code(CodeRequest(prompt="Uzraksti Python skriptu", language="Python"))

    # Fail with a clear message instead of an IndexError when the pipeline was
    # never invoked or received an empty conversation.
    assert captured_messages, "text pipeline was not called with any messages"
    system_prompt = captured_messages[0]["content"]
    assert "production-ready" in system_prompt
    assert "edge cases" in system_prompt
    # Latvian phrase ("executable artifact") that the prompt is expected to contain.
    assert "izpildāmu artefaktu" in system_prompt
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_materializes_workspace_artifacts_from_structured_payload(
    tmp_path: Path,
) -> None:
    """A structured multi-file reply is written under the workspace artifact root.

    The pipeline returns a fenced JSON payload with two files. The response must
    report the entrypoint, the detected stack, an existing artifact directory,
    an existing bundle, and an absolute path for the first file that points at
    its written content.
    """

    def fake_pipeline(
        messages: list[dict[str, Any]], *, max_new_tokens: int, temperature: float
    ) -> list[dict[str, list[dict[str, str]]]]:
        del messages, max_new_tokens, temperature
        return [
            {
                "generated_text": [
                    {
                        "role": "assistant",
                        "content": """```json
{
"explanation": "Pilns kalkulatora projekts.",
"entrypoint": "index.html",
"files": [
{"path": "index.html", "content": "<!doctype html><title>Kalkulators</title>"},
{"path": "assets/app.js", "content": "console.log('calc');"}
]
}
```""",
                    }
                ]
            }
        ]

    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=fake_pipeline),
        patch("maris_core.code.generate_code.WORKSPACE_ARTIFACT_ROOT", tmp_path),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate_code(
            CodeRequest(prompt="Uzprogrammē kalkulatora projektu", language="HTML/CSS/JavaScript")
        )

    assert response.entrypoint == "index.html"
    assert response.detected_stack == "HTML/CSS/JavaScript"
    assert [file.path for file in response.files] == ["index.html", "assets/app.js"]

    # Each optional path is asserted present before it is checked on disk:
    # Path(None or "") is Path("."), which always exists and would make the
    # existence check pass vacuously.
    assert response.workspace_artifact_dir is not None
    assert Path(response.workspace_artifact_dir).exists()
    assert response.bundle_path, "expected a bundle path to be reported"
    assert Path(response.bundle_path).exists()

    index_path = response.files[0].absolute_path
    assert index_path is not None
    assert Path(index_path).read_text(encoding="utf-8").startswith("<!doctype html>")
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_auto_scaffolds_nextjs_project_from_single_code_block(
    tmp_path: Path,
) -> None:
    """A lone TSX block for a Next.js prompt is scaffolded into a full project.

    The prompt names Next.js even though ``language="Python"`` is requested. The
    response must switch to the Next.js stack, place the code at
    ``app/page.tsx``, include a ``package.json``, and produce a bundle on disk.
    """

    def fake_pipeline(
        messages: list[dict[str, Any]], *, max_new_tokens: int, temperature: float
    ) -> list[dict[str, list[dict[str, str]]]]:
        del messages, max_new_tokens, temperature
        return [
            {
                "generated_text": [
                    {
                        "role": "assistant",
                        "content": """```tsx
export default function HomePage() {
return <main>Analytics dashboard</main>;
}
```""",
                    }
                ]
            }
        ]

    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=fake_pipeline),
        patch("maris_core.code.generate_code.WORKSPACE_ARTIFACT_ROOT", tmp_path),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate_code(
            CodeRequest(
                prompt="Uztaisi Next.js app router dashboard ar TypeScript",
                language="Python",
            )
        )

    paths = [file.path for file in response.files]
    assert response.detected_stack == "Next.js (TypeScript)"
    assert response.language == "Next.js (TypeScript)"
    assert response.entrypoint == "app/page.tsx"
    assert "app/page.tsx" in paths
    assert "package.json" in paths
    # Assert presence first: Path(None or "") is Path("."), which always
    # exists and would make the existence check pass vacuously.
    assert response.bundle_path, "expected a bundle path to be reported"
    assert Path(response.bundle_path).exists()
|
|
|
|
@pytest.mark.asyncio
async def test_generate_code_uses_repo_aware_entrypoint_for_existing_project(
    tmp_path: Path,
) -> None:
    """An existing repo's layout decides the entrypoint and is described to the model.

    A minimal Python project (``pyproject.toml`` plus ``src/main.py``) is created
    on disk and passed as ``repo_path``. The response must target
    ``src/main.py``, and the second chat message sent to the pipeline must
    mention the repo root ("Repo sakne") and that file.
    """
    repo_root = tmp_path / "existing-project"
    source_dir = repo_root / "src"
    source_dir.mkdir(parents=True)
    (repo_root / "pyproject.toml").write_text(
        "[project]\nname = 'demo'\nversion = '0.1.0'\n",
        encoding="utf-8",
    )
    (source_dir / "main.py").write_text("def main() -> None:\n print('old')\n", encoding="utf-8")

    sent_messages: list[dict[str, Any]] = []

    def fake_pipeline(
        messages: list[dict[str, Any]], *, max_new_tokens: int, temperature: float
    ) -> list[dict[str, list[dict[str, str]]]]:
        nonlocal sent_messages
        del max_new_tokens, temperature
        sent_messages = messages
        assistant_reply = """```python
def main() -> None:
print('updated')
```"""
        return [{"generated_text": [{"role": "assistant", "content": assistant_reply}]}]

    request = CodeRequest(
        prompt="Salabo esošo servisu un atjauno starta loģiku",
        language="Python",
        repo_path=str(repo_root),
    )
    with (
        patch("maris_core.code.generate_code.get_pipeline", return_value=fake_pipeline),
        patch("maris_core.code.generate_code.WORKSPACE_ARTIFACT_ROOT", tmp_path / "artifacts"),
        patch(
            "maris_core.utils.hf_integration.HFIntegration.save_generation",
            new_callable=AsyncMock,
        ),
    ):
        response = await generate_code(request)

    assert response.repo_path == str(repo_root)
    assert response.detected_stack == "Python"
    assert response.entrypoint == "src/main.py"
    assert response.files[0].path == "src/main.py"

    repo_context = sent_messages[1]["content"]
    assert "Repo sakne" in repo_context
    assert "src/main.py" in repo_context
|
|