Spaces:
Build error
Build error
| import opik | |
| from langchain_ollama import ChatOllama | |
| from loguru import logger | |
| from llm_engineering.domain.queries import Query | |
| from llm_engineering.settings import settings | |
| from .base import RAGStep | |
| from .prompt_templates import QueryExpansionTemplate | |
class QueryExpansion(RAGStep):
    """RAG step that expands one query into several queries using an LLM."""

    def generate(self, query: Query, expand_to_n: int) -> list[Query]:
        """Expand ``query`` into up to ``expand_to_n`` queries.

        The original query is always the first element. The remaining entries
        are variants generated by the Ollama chat model named by
        ``settings.LLAMA_MODEL_ID``.

        Args:
            query: The query to expand.
            expand_to_n: Total number of queries wanted, including the original
                one. Must be greater than 0.

        Returns:
            A list that starts with ``query``. In mock mode it contains
            ``expand_to_n`` references to ``query``. Otherwise it contains the
            original query plus at most ``expand_to_n - 1`` non-empty generated
            variants; there are fewer if the model returned fewer.

        Raises:
            AssertionError: If ``expand_to_n`` is not positive. Note that this
                check is skipped when Python runs with ``-O``.
        """
        assert expand_to_n > 0, f"'expand_to_n' should be greater than 0. Got {expand_to_n}."

        if self._mock:
            return [query for _ in range(expand_to_n)]

        n_alternatives = expand_to_n - 1
        # Nothing to generate: skip the LLM round-trip entirely instead of asking
        # the model for zero alternatives (it may return some anyway).
        if n_alternatives == 0:
            return [query]

        query_expansion_template = QueryExpansionTemplate()
        prompt = query_expansion_template.create_template(n_alternatives)
        model = ChatOllama(model=settings.LLAMA_MODEL_ID, temperature=0)
        chain = prompt | model

        # Feed the raw question text. Passing the Query object itself would make
        # the template stringify the whole object (ids, metadata) into the prompt.
        response = chain.invoke({"question": query.content})
        result = response.content

        queries_content = result.strip().split(query_expansion_template.separator)
        expanded = [
            query.replace_content(stripped_content)
            for content in queries_content
            if (stripped_content := content.strip())
        ]

        # The model may ignore the requested count; cap the variants so callers
        # never receive more than ``expand_to_n`` queries in total.
        return [query, *expanded[:n_alternatives]]
if __name__ == "__main__":
    # Manual smoke run: expand a sample query and log each resulting query text.
    sample_query = Query.from_str("Write an article about the best types of advanced RAG methods.")
    expander = QueryExpansion()
    for generated in expander.generate(sample_query, expand_to_n=3):
        logger.info(generated.content)