File size: 2,640 Bytes
64d7fdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import uuid

import pytest
from app.core.pipeline import rag_pipeline


@pytest.mark.integration
class TestRAGPipeline:
    """Integration tests that exercise the shared ``rag_pipeline`` instance.

    These tests call the real pipeline (no mocks), so their outcome depends on
    whatever backends ``rag_pipeline`` is configured with. Tests taking the
    ``clear_cache`` fixture (defined outside this file, presumably in
    ``conftest.py``) rely on it to reset the response cache beforehand.
    """

    @pytest.mark.asyncio
    async def test_pipeline_initialization(self):
        """Every pipeline component attribute is populated."""
        assert rag_pipeline.retriever is not None
        assert rag_pipeline.reranker is not None
        assert rag_pipeline.generator is not None
        assert rag_pipeline.cache is not None
        assert rag_pipeline.memory is not None

    @pytest.mark.asyncio
    async def test_generate_without_rag(self, clear_cache):
        """Generation with ``use_context=False`` returns a non-empty string."""
        query = "What is 2 + 2?"

        response = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )

        assert response is not None
        assert isinstance(response, str)
        assert len(response) > 0

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_generate_with_rag(self, clear_cache):
        """Generation with ``use_context=True`` returns a non-empty string.

        The test is skipped, not failed, when the ``generate`` call itself
        raises. Presumably this is because the retrieval backend is not
        available in every environment. Only the call is guarded: assertion
        failures below must still fail the test rather than become skips.
        """
        query = "Explain the attention mechanism"

        try:
            response = await rag_pipeline.generate(
                query=query,
                use_context=True,
            )
        except Exception as e:  # deliberate best-effort: skip on backend errors
            # pytest.skip raises, so `response` is always bound below.
            pytest.skip(f"RAG test skipped: {e}")

        assert response is not None
        assert isinstance(response, str)
        assert len(response) > 0

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_streaming_response(self, clear_cache):
        """``stream`` yields at least one chunk; the chunks join to non-empty text."""
        query = "What is machine learning?"

        chunks = [
            chunk
            async for chunk in rag_pipeline.stream(
                query=query,
                use_context=False,
            )
        ]

        assert len(chunks) > 0
        # Assumes chunks are str fragments of the final answer.
        full_response = "".join(chunks)
        assert len(full_response) > 0

    @pytest.mark.asyncio
    async def test_pipeline_with_session(self, clear_cache):
        """A ``generate`` call with a ``session_id`` records messages in memory.

        A fresh, unique session id is used on every run. Otherwise history
        left over from earlier tests or runs could satisfy the message-count
        assertion even if this call stored nothing.
        """
        session_id = f"test-session-{uuid.uuid4().hex}"

        response1 = await rag_pipeline.generate(
            query="My name is Alice",
            session_id=session_id,
            use_context=False,
        )

        assert response1 is not None

        messages = rag_pipeline.memory.get_messages(session_id)
        # Presumably one user message plus one assistant reply; verify
        # against the memory implementation if this becomes exact.
        assert len(messages) >= 2

    @pytest.mark.asyncio
    async def test_cache_integration(self, clear_cache):
        """Asking the same query twice yields an identical response.

        With the cache cleared first, the second call is presumably served
        from the cache populated by the first call.
        """
        query = "What is Python?"

        response1 = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )

        response2 = await rag_pipeline.generate(
            query=query,
            use_context=False,
        )

        assert response1 == response2