# rag_template / tests/test_query_expansion.py
# Author: Guilherme Favaron
# Commit: Major update: Add hybrid search, reranking, multiple LLMs, and UI improvements
# Revision: 1b447de
"""
Testes para módulo de expansão de queries
"""
import pytest
from src.query_expansion import QueryExpander
from src.generation import GenerationManager
class TestQueryExpander:
    """Unit tests for the QueryExpander class."""

    @pytest.fixture
    def generation_manager(self):
        """Provide a fresh GenerationManager for each test."""
        return GenerationManager()

    @pytest.fixture
    def expander(self, generation_manager):
        """Provide a QueryExpander built on the generation_manager fixture."""
        return QueryExpander(generation_manager)

    def test_initialization(self, generation_manager):
        """A newly built expander exposes a non-None generation_manager."""
        expander = QueryExpander(generation_manager)
        assert expander.generation_manager is not None

    def test_expand_query_template(self, expander):
        """Template expansion keeps the original query and caps the output size."""
        query = "machine learning"
        num_variations = 3
        variations = expander.expand_query(
            query, num_variations=num_variations, method="template"
        )
        assert len(variations) > 0
        assert query in variations  # the original query must be included
        assert len(variations) <= num_variations + 1  # original + N variations

    def test_expand_query_paraphrase(self, expander):
        """Paraphrase expansion returns a non-empty list of strings."""
        query = "o que é inteligência artificial?"
        variations = expander.expand_query(query, num_variations=2, method="paraphrase")
        assert len(variations) > 0
        assert isinstance(variations, list)
        assert all(isinstance(v, str) for v in variations)

    def test_expand_query_unknown_method(self, expander):
        """An unknown method falls back to returning only the original query."""
        query = "test query"
        variations = expander.expand_query(query, num_variations=3, method="unknown")
        assert variations == [query]

    def test_parse_llm_variations_numbered(self, expander):
        """Numbered LLM output ("1. ...") is parsed into one item per line."""
        # Explicit concatenation keeps the literal's content independent of
        # the surrounding indentation (a triple-quoted block would not).
        response = (
            "\n"
            "1. What is machine learning?\n"
            "2. How does machine learning work?\n"
            "3. Explain machine learning concepts\n"
        )
        variations = expander._parse_llm_variations(response)
        assert len(variations) == 3
        assert "What is machine learning?" in variations
        assert "How does machine learning work?" in variations
        assert "Explain machine learning concepts" in variations

    def test_parse_llm_variations_bullets(self, expander):
        """Bulleted LLM output ("- ..." / "* ...") is parsed into items."""
        response = (
            "\n"
            "- Machine learning definition\n"
            "- What is ML?\n"
            "* How ML algorithms work\n"
        )
        variations = expander._parse_llm_variations(response)
        assert len(variations) >= 2  # at least the "-" and "*" bullet items

    def test_parse_llm_variations_empty(self, expander):
        """An empty LLM response parses to an empty list."""
        response = ""
        variations = expander._parse_llm_variations(response)
        assert variations == []

    def test_template_expansion_preserves_original(self, expander):
        """Template expansion keeps the original query as the first element."""
        query = "Python programming"
        variations = expander._expand_with_templates(query, num_variations=3)
        assert query in variations
        assert variations[0] == query  # the original comes first

    def test_paraphrase_expansion_basic(self, expander):
        """Paraphrase expansion returns a non-empty list containing the query."""
        query = "o que é deep learning?"
        variations = expander._expand_with_paraphrase(query, num_variations=2)
        assert len(variations) > 0
        assert query in variations

    def test_paraphrase_substitutions(self, expander):
        """Paraphrase expansion on a query with a substitutable verb."""
        query = "explique machine learning"
        variations = expander._expand_with_paraphrase(query, num_variations=3)
        # A "descreva" variant is expected for "explique", but it is deliberately
        # NOT asserted: it may be dropped once the variation limit is reached.
        assert isinstance(variations, list)
        assert all(isinstance(v, str) for v in variations)
        assert query in variations

    def test_get_expansion_info_llm(self, expander):
        """Metadata for the LLM method contains every descriptive field."""
        info = expander.get_expansion_info("llm")
        for key in ("name", "description", "pros", "cons", "best_for"):
            assert key in info
        # The previous check compared info["type"] to "cross-encoder" (a reranker
        # copy-paste) and raised KeyError when no "type" key existed; check the
        # name instead, like the template/paraphrase tests do.
        assert info["name"] == "LLM-based"

    def test_get_expansion_info_template(self, expander):
        """Metadata for the template method advertises its speed."""
        info = expander.get_expansion_info("template")
        assert info["name"] == "Template-based"
        pros = info["pros"].lower()
        assert "rápido" in pros or "fast" in pros

    def test_get_expansion_info_paraphrase(self, expander):
        """Metadata for the paraphrase method has its name and a description."""
        info = expander.get_expansion_info("paraphrase")
        assert info["name"] == "Paraphrase-based"
        assert "description" in info

    def test_get_expansion_info_unknown(self, expander):
        """Metadata for an unknown method echoes the method name back."""
        info = expander.get_expansion_info("unknown_method")
        assert "name" in info
        assert info["name"] == "unknown_method"

    def test_expansion_returns_strings(self, expander):
        """Every non-LLM expansion method yields only strings."""
        query = "test"
        for method in ["template", "paraphrase"]:
            variations = expander.expand_query(query, num_variations=2, method=method)
            assert all(isinstance(v, str) for v in variations)

    def test_expansion_num_variations_respected(self, expander):
        """Template expansion never exceeds the requested count plus the original."""
        query = "artificial intelligence"
        num_vars = 3
        variations = expander._expand_with_templates(query, num_vars)
        assert len(variations) <= num_vars + 1  # +1 for the original query
class TestQueryExpansionIntegration:
    """Integration tests for query expansion."""

    @pytest.fixture
    def generation_manager(self):
        """Provide a fresh GenerationManager for each test."""
        return GenerationManager()

    @pytest.fixture
    def expander(self, generation_manager):
        """Provide a QueryExpander built on the generation_manager fixture."""
        return QueryExpander(generation_manager)

    def test_llm_expansion_with_real_query(self, expander):
        """LLM expansion of a real query; skipped if the LLM call itself raises."""
        query = "What is Python programming?"
        try:
            variations = expander.expand_query(query, num_variations=2, method="llm")
        except Exception as e:
            # Guard ONLY the call: AssertionError subclasses Exception, so
            # wrapping the assertions too would turn real failures into skips.
            pytest.skip(f"LLM não disponível: {e}")
        # pytest.skip raises, so `variations` is always bound here.
        assert len(variations) > 0
        assert all(isinstance(v, str) for v in variations)

    def test_different_methods_produce_different_results(self, expander):
        """Template and paraphrase expansion of the same query both return lists.

        The two result sets are expected to differ, but overlap is possible,
        so only the result types are checked here.
        """
        query = "machine learning algorithms"
        template_vars = expander.expand_query(query, num_variations=2, method="template")
        paraphrase_vars = expander.expand_query(query, num_variations=2, method="paraphrase")
        assert isinstance(template_vars, list)
        assert isinstance(paraphrase_vars, list)

    def test_expansion_handles_special_characters(self, expander):
        """Queries with punctuation and accents expand to non-empty string lists."""
        query = "O que é IA? E ML?"
        for method in ["template", "paraphrase"]:
            variations = expander.expand_query(query, num_variations=2, method=method)
            assert len(variations) > 0
            assert all(isinstance(v, str) for v in variations)

    def test_expansion_handles_long_queries(self, expander):
        """A long query expands and is kept verbatim among the variations."""
        query = "Explain the differences between supervised learning, unsupervised learning, and reinforcement learning in machine learning"
        variations = expander.expand_query(query, num_variations=2, method="template")
        assert len(variations) > 0
        assert query in variations