"""
Retriever Module for Module C

Finds relevant templates based on user query/intent.
"""

import logging
from typing import Any, Dict, List, Optional

from module_a.embeddings import EmbeddingGenerator

from .vector_db import TemplateVectorDB

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TemplateRetriever:
    """Retrieve the most relevant letter templates for a given user query.

    A query string is turned into an embedding by ``self.embedder`` and that
    embedding is passed to ``self.db.query_with_embedding``; the hits are
    returned as plain dictionaries.
    """

    def __init__(
        self,
        db: Optional["TemplateVectorDB"] = None,
        embedder: Optional["EmbeddingGenerator"] = None,
    ):
        """Set up the vector store and the embedding generator.

        Args:
            db: Vector store to query. A new ``TemplateVectorDB()`` is
                created when omitted.
            embedder: Embedding generator. A new ``EmbeddingGenerator()`` is
                created when omitted.
        """
        # "is None" rather than "or": an injected collaborator may be falsy.
        self.db = TemplateVectorDB() if db is None else db
        self.embedder = EmbeddingGenerator() if embedder is None else embedder

    def retrieve_templates(self, query: str, k: int = 1) -> List[Dict[str, Any]]:
        """Retrieve the top-k templates matching the query.

        Args:
            query: User query / intent text to embed.
            k: Number of results requested from the vector store.

        Returns:
            One dict per hit with the keys ``"filename"`` (the stored id),
            ``"content"`` (the document text), ``"metadata"`` and ``"score"``
            (``1.0 - distance``). A field the store did not return is
            reported as ``None``. An empty list is returned when nothing
            matched.
        """
        logger.info("Retrieving templates for query: %s", query)

        query_embedding = self.embedder.generate_embedding(query)
        results = self.db.query_with_embedding(
            self._to_list(query_embedding), n_results=k
        )

        # Results are indexed as results[field][0][i]: one inner list per
        # query, and exactly one query is issued here.
        documents = self._first_batch(results, "documents")
        hit_count = len(documents)
        ids = self._first_batch(results, "ids", hit_count)
        metadatas = self._first_batch(results, "metadatas", hit_count)
        distances = self._first_batch(results, "distances", hit_count)

        # NOTE(review): "1.0 - distance" assumes a metric where smaller means
        # closer and values stay within [0, 1] -- confirm the store's metric,
        # otherwise the score can become negative.
        retrieved = [
            {
                "filename": template_id,
                "content": document,
                "metadata": metadata,
                "score": None if distance is None else 1.0 - distance,
            }
            for template_id, document, metadata, distance in zip(
                ids, documents, metadatas, distances
            )
        ]

        logger.info("Found %d templates.", len(retrieved))
        return retrieved

    @staticmethod
    def _to_list(embedding: Any) -> List[Any]:
        """Return *embedding* as a plain list (array-like or any iterable)."""
        to_list = getattr(embedding, "tolist", None)
        return to_list() if callable(to_list) else list(embedding)

    @staticmethod
    def _first_batch(results: Any, key: str, size: Optional[int] = None) -> List[Any]:
        """Return the first query's values for *key*, padded with ``None``.

        A missing key, a ``None`` value or an empty outer list all yield an
        empty batch instead of raising, so "no hits" is not an error.
        """
        try:
            batch = results[key][0]
        except (KeyError, IndexError, TypeError):
            batch = None
        values = [] if batch is None else list(batch)
        if size is not None and len(values) < size:
            values.extend([None] * (size - len(values)))
        return values