File size: 3,108 Bytes
d557d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Retrieval module for semantic search."""

from typing import Optional

from coderag.config import get_settings
from coderag.indexing.embeddings import EmbeddingGenerator
from coderag.indexing.vectorstore import VectorStore
from coderag.logging import get_logger
from coderag.models.chunk import Chunk
from coderag.models.response import RetrievedChunk

# Module-level logger; Retriever.retrieve() calls it with keyword fields
# (query, repo_id, top_k, count) alongside the message text.
logger = get_logger(__name__)


class Retriever:
    """Retrieve the chunks of an indexed repository that best match a query.

    The query is embedded with an ``EmbeddingGenerator`` and searched against
    a ``VectorStore``. Defaults for ``top_k``, its upper bound, and the
    similarity threshold are read from ``get_settings().retrieval`` at
    construction time.
    """

    def __init__(
        self,
        vectorstore: Optional[VectorStore] = None,
        embedder: Optional[EmbeddingGenerator] = None,
    ) -> None:
        """Create a retriever.

        Args:
            vectorstore: Vector store to search. A new ``VectorStore()`` is
                created when omitted.
            embedder: Generator used to embed queries. A new
                ``EmbeddingGenerator()`` is created when omitted.
        """
        settings = get_settings()
        # Explicit None checks: an injected collaborator must never be
        # swapped for a fresh default just because it evaluates as falsy.
        self.vectorstore = VectorStore() if vectorstore is None else vectorstore
        self.embedder = EmbeddingGenerator() if embedder is None else embedder
        self.default_top_k = settings.retrieval.default_top_k
        self.max_top_k = settings.retrieval.max_top_k
        self.similarity_threshold = settings.retrieval.similarity_threshold

    def retrieve(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> list[RetrievedChunk]:
        """Return the chunks of ``repo_id`` most similar to ``query``.

        Args:
            query: Query text; it is embedded with ``is_query=True``.
            repo_id: Repository whose chunks are searched.
            top_k: Maximum number of chunks to return. ``None`` or any value
                below 1 falls back to the configured default. The result is
                always capped at the configured maximum.
            similarity_threshold: Minimum similarity forwarded to the vector
                store. ``None`` uses the configured threshold; an explicit
                ``0.0`` is honoured.

        Returns:
            The retrieved chunks, in the order the vector store returned
            them, each carrying its score as ``relevance_score``.
        """
        # Non-positive values are meaningless as a result count; treat them
        # like "not specified" (the original already did so for 0).
        if top_k is None or top_k < 1:
            top_k = self.default_top_k
        top_k = min(top_k, self.max_top_k)
        threshold = (
            self.similarity_threshold
            if similarity_threshold is None
            else similarity_threshold
        )

        logger.info("Retrieving chunks", query=query[:100], repo_id=repo_id, top_k=top_k)

        query_embedding = self.embedder.generate_embedding(query, is_query=True)
        results = self.vectorstore.query(
            query_embedding=query_embedding,
            repo_id=repo_id,
            top_k=top_k,
            similarity_threshold=threshold,
        )

        retrieved_chunks = [
            self._to_retrieved_chunk(chunk, score) for chunk, score in results
        ]

        logger.info("Chunks retrieved", count=len(retrieved_chunks))
        return retrieved_chunks

    def retrieve_with_context(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> tuple[list[RetrievedChunk], str]:
        """Retrieve chunks and render them as a numbered context string for an LLM.

        Args:
            query: Query text (see ``retrieve``).
            repo_id: Repository whose chunks are searched.
            top_k: Maximum number of chunks (see ``retrieve``).
            similarity_threshold: Optional threshold override (see
                ``retrieve``).

        Returns:
            ``(chunks, context)``. ``context`` is
            ``"No relevant code found."`` when no chunk was retrieved.
        """
        chunks = self.retrieve(
            query,
            repo_id,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
        )
        return chunks, self._format_context(chunks)

    @staticmethod
    def _to_retrieved_chunk(chunk: Chunk, score: float) -> RetrievedChunk:
        """Convert a stored chunk and its similarity score into a RetrievedChunk."""
        return RetrievedChunk(
            chunk_id=chunk.id,
            content=chunk.content,
            file_path=chunk.file_path,
            start_line=chunk.start_line,
            end_line=chunk.end_line,
            relevance_score=score,
            chunk_type=chunk.chunk_type.value,
            name=chunk.name,
        )

    @classmethod
    def _format_context(cls, chunks: list[RetrievedChunk]) -> str:
        """Render chunks as numbered, fenced code sections separated by blank lines."""
        if not chunks:
            return "No relevant code found."

        parts = []
        for i, chunk in enumerate(chunks, 1):
            name_suffix = f" | Name: {chunk.name}" if chunk.name else ""
            fence = cls._code_fence(chunk.content)
            parts.append(
                f"[{i}] {chunk.citation}\n"
                f"Type: {chunk.chunk_type}{name_suffix}\n"
                f"{fence}\n{chunk.content}\n{fence}\n"
            )
        return "\n".join(parts)

    @staticmethod
    def _code_fence(content: str) -> str:
        """Return a backtick fence that cannot be closed by text inside ``content``.

        Starts at three backticks (the common case, identical to a plain
        fence) and grows until it is longer than every backtick run in
        ``content``. Otherwise a chunk that itself contains a triple-backtick
        run would end the fence early.
        """
        fence = "```"
        while fence in content:
            fence += "`"
        return fence