File size: 3,782 Bytes
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efdd22e
 
 
 
 
0da0699
 
 
efdd22e
 
 
 
 
 
 
0da0699
efdd22e
 
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# backend/tests/conftest.py
# Shared fixtures for all PersonaBot backend tests.
# Sets all required env vars so Settings() never fails with a missing-field error.
# Tests run against no real external services — every dependency is mocked.

import os
import time
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from jose import jwt

# Environment defaults applied at import time, i.e. before any `app.*` module
# is imported, so pydantic-settings can read them when Settings() is built.
# setdefault() semantics: a value already exported in the shell/CI is kept.
_TEST_ENV_DEFAULTS = {
    "ENVIRONMENT": "test",
    "LLM_PROVIDER": "groq",
    "GROQ_API_KEY": "gsk_test_key_not_real",
    "QDRANT_URL": "http://localhost:6333",
    "JWT_SECRET": "test-secret-32-chars-long-0000000",
    "ALLOWED_ORIGIN": "http://localhost:3000",
    "EMBEDDER_URL": "http://localhost:7860",
    "RERANKER_URL": "http://localhost:7861",
    "DB_PATH": "/tmp/personabot_test.db",
}
for _env_name, _env_value in _TEST_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

# Read back from the environment (not the default above) so tokens are signed
# with whatever secret the app will actually be configured with.
TEST_JWT_SECRET = os.environ["JWT_SECRET"]
TEST_ALGORITHM = "HS256"

TEST_SESSION_ID = "a1b2c3d4-e5f6-4789-8abc-def012345678"


def make_jwt(secret: str = TEST_JWT_SECRET, exp_offset: int = 3600, **extra) -> str:
    """Sign and return a JWT for use in test requests.

    Args:
        secret: Signing key; defaults to the test JWT_SECRET.
        exp_offset: Seconds added to the current whole-second Unix time to
            form the ``exp`` claim. A negative value yields an expired token.
        **extra: Additional claims. A key equal to ``sub`` or ``exp``
            replaces the default value for that claim.

    Returns:
        The encoded token, signed with ``TEST_ALGORITHM``.
    """
    claims = {"sub": "test-user"}
    claims["exp"] = int(time.time()) + exp_offset
    claims.update(extra)
    return jwt.encode(claims, secret, algorithm=TEST_ALGORITHM)


@pytest.fixture
def valid_token() -> str:
    """Provide a JWT signed with the test secret and make_jwt's default lifetime."""
    token = make_jwt()
    return token


@pytest.fixture
def expired_token() -> str:
    """Return a JWT, signed with the test secret, whose ``exp`` is one hour in the past.

    A large negative offset is used rather than -1 second. A token only one
    second past expiry would still be accepted if the token verifier allows
    any clock-skew leeway, which would make expiry tests flaky or wrong.
    """
    return make_jwt(exp_offset=-3600)


@pytest.fixture
def wrong_secret_token() -> str:
    """Provide an unexpired JWT signed with a hard-coded secret other than the test default.

    Intended for tests that exercise signature-verification failures.
    """
    foreign_secret = "completely-different-secret-0000"
    return make_jwt(secret=foreign_secret)


@pytest.fixture
def app_client():
    """Yield a TestClient for the app with a mocked pipeline and no real services.

    ``app.main.build_pipeline``, ``app.main.QdrantClient``,
    ``app.services.embedder.Embedder`` and ``app.services.reranker.Reranker``
    are patched while the client is alive. This keeps app startup from trying
    to connect to Qdrant or HF Spaces. The pipeline is a MagicMock whose
    ``astream`` replays a fixed, non-cached answer ("I built TextOps.").

    The ``get_settings`` lru_cache is cleared before the app is built, so the
    test env vars are picked up. It is cleared again on teardown, so the
    Settings instance built for this fixture does not leak into later tests.
    """
    from app.core.config import get_settings
    get_settings.cache_clear()

    mock_pipeline = MagicMock()

    async def fake_astream(state, stream_mode=None):
        # chat.py passes stream_mode=["custom", "updates"] and consumes
        # (mode, payload) tuples.
        if isinstance(stream_mode, list):
            yield ("custom", {"type": "status", "label": "Checking your question"})
            yield ("updates", {"guard": {"guard_passed": True}})
            # The enumerate_query node runs after guard on every request; for a
            # non-enumeration query it sets is_enumeration_query=False.
            yield ("updates", {"enumerate_query": {"is_enumeration_query": False}})
            yield ("updates", {"cache": {"cached": False}})
            yield ("custom", {"type": "status", "label": "Thinking about your question directly..."})
            yield ("custom", {"type": "token", "text": "I built TextOps."})
            yield ("updates", {"generate": {"answer": "I built TextOps.", "sources": []}})
        else:
            # Fallback for any code that still calls astream without stream_mode:
            # plain per-node update dicts, no custom events.
            yield {"guard": {"guard_passed": True}}
            yield {"enumerate_query": {"is_enumeration_query": False}}
            yield {"cache": {"cached": False}}
            yield {"generate": {"answer": "I built TextOps.", "sources": []}}

    mock_pipeline.astream = fake_astream

    try:
        # Patch the lifespan dependencies so TestClient doesn't try to connect
        # to Qdrant or HF Spaces.
        with patch("app.main.build_pipeline", return_value=mock_pipeline), \
             patch("app.main.QdrantClient"), \
             patch("app.services.embedder.Embedder"), \
             patch("app.services.reranker.Reranker"):
            from app.main import create_app
            app = create_app()
            app.state.pipeline = mock_pipeline
            with TestClient(app, raise_server_exceptions=True) as client:
                yield client
    finally:
        # Drop the Settings cached during this fixture. Later tests then
        # rebuild it from the environment as it is at that point.
        get_settings.cache_clear()