Spaces:
Sleeping
Sleeping
File size: 2,640 Bytes
64d7fdf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | import pytest
from app.core.pipeline import rag_pipeline
@pytest.mark.integration
class TestRAGPipeline:
    """Integration tests for the shared ``rag_pipeline`` instance.

    Most tests request a ``clear_cache`` fixture that is defined outside this
    file (presumably in ``conftest.py``). Judging by its name it resets the
    pipeline cache between tests — TODO confirm.
    """

    def test_pipeline_initialization(self):
        """All pipeline components are wired up on the shared instance."""
        # Plain sync test: nothing here is awaited, so the asyncio mark
        # (and an event loop) is unnecessary.
        assert rag_pipeline.retriever is not None
        assert rag_pipeline.reranker is not None
        assert rag_pipeline.generator is not None
        assert rag_pipeline.cache is not None
        assert rag_pipeline.memory is not None

    @pytest.mark.asyncio
    async def test_generate_without_rag(self, clear_cache):
        """With retrieval disabled, ``generate`` returns a non-empty string."""
        query = "What is 2 + 2?"
        response = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )
        assert response is not None
        assert isinstance(response, str)
        assert len(response) > 0

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_generate_with_rag(self, clear_cache):
        """With retrieval enabled, ``generate`` returns a non-empty string.

        Any exception raised by the pipeline call skips the test (presumably
        to tolerate environments without a usable retrieval backend). Failed
        assertions on the response still fail the test.
        """
        query = "Explain the attention mechanism"
        try:
            response = await rag_pipeline.generate(
                query=query,
                use_context=True,
            )
        except Exception as e:
            # Guard only the pipeline call. The assertions used to sit inside
            # this try, so an AssertionError (an Exception subclass) was
            # silently converted into a skip and the test could never fail.
            pytest.skip(f"RAG test skipped: {str(e)}")
        assert response is not None
        assert isinstance(response, str)
        assert len(response) > 0

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_streaming_response(self, clear_cache):
        """``stream`` yields at least one chunk and the joined text is non-empty."""
        query = "What is machine learning?"
        chunks = [
            chunk
            async for chunk in rag_pipeline.stream(
                query=query,
                use_context=False,
            )
        ]
        assert len(chunks) > 0
        # Joining with "" assumes the stream yields str chunks.
        full_response = "".join(chunks)
        assert len(full_response) > 0

    @pytest.mark.asyncio
    async def test_pipeline_with_session(self, clear_cache):
        """A generate call with a ``session_id`` records messages in memory."""
        import uuid

        # Unique per run: ``rag_pipeline`` is a module-level singleton, so a
        # fixed id could see messages left over from an earlier call (or an
        # earlier run, if memory is persisted) and pass trivially.
        session_id = f"test-session-{uuid.uuid4()}"
        response = await rag_pipeline.generate(
            query="My name is Alice",
            session_id=session_id,
            use_context=False,
        )
        assert response is not None
        messages = rag_pipeline.memory.get_messages(session_id)
        # Presumably one user message plus one assistant reply — confirm
        # against the memory implementation.
        assert len(messages) >= 2

    @pytest.mark.asyncio
    async def test_cache_integration(self, clear_cache):
        """Repeating the same query returns an identical response.

        Equality is expected because the second call is presumably served
        from ``rag_pipeline.cache`` rather than regenerated — TODO confirm.
        """
        query = "What is Python?"
        first = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )
        second = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )
        assert first == second
|