RAG_AIEXP_01 / utils.py
MrSimple07's picture
top k reranker = 20, max rows = 10, max chars= 2000 + new deduplication
ec64429
raw
history blame
5.61 kB
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers import CrossEncoder
from my_logging import log_message
def get_llm_model(api_key, model_name="gemini-2.0-flash"):
    """Build and return a Google GenAI LLM client for the given API key."""
    llm = GoogleGenAI(model=model_name, api_key=api_key)
    return llm
def get_embedding_model(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
    """Build and return a HuggingFace embedding model for the given checkpoint."""
    embedder = HuggingFaceEmbedding(model_name=model_name)
    return embedder
def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
    """Build and return a cross-encoder reranker for the given checkpoint."""
    reranker = CrossEncoder(model_name)
    return reranker
def format_sources(nodes):
    """Format retrieved source nodes into a human-readable citation list.

    Each node is rendered according to its metadata ``type``: tables get a
    table number/title line, images an image-number line, and everything
    else a section line. Duplicate citations are removed.

    Fix: the original joined ``set(sources)``, which drops duplicates but
    produces a nondeterministic line order across runs (string hash
    randomization). ``dict.fromkeys`` de-duplicates while preserving the
    order in which sources first appear.

    Args:
        nodes: iterable of objects exposing a ``metadata`` dict.

    Returns:
        str: newline-joined, order-preserving, de-duplicated source lines.
    """
    sources = []
    for node in nodes:
        meta = node.metadata
        doc_type = meta.get('type', 'text')
        doc_id = meta.get('document_id', 'unknown')
        if doc_type == 'table':
            table_num = meta.get('table_number', 'unknown')
            title = meta.get('table_title', '')
            sources.append(f"📊 {doc_id} - Таблица {table_num}: {title}")
        elif doc_type == 'image':
            img_num = meta.get('image_number', 'unknown')
            sources.append(f"🖼️ {doc_id} - Рисунок {img_num}")
        else:
            section = meta.get('section_id', '')
            sources.append(f"📄 {doc_id} - Раздел {section}")
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return "\n".join(dict.fromkeys(sources))
import re
def answer_question(question, query_engine, reranker):
    """Answer *question* via a retrieve -> rerank -> grouped-context LLM call.

    Retrieves candidate nodes from ``query_engine``, reranks them with
    ``reranker``, groups the survivors by source document (tables / text /
    images), assembles a per-document context string, and completes a
    prompt against the globally configured LLM.

    Returns:
        tuple[str, str]: (answer text, formatted source list). On any
        exception, returns an error message string and an empty sources
        string instead of raising.
    """
    try:
        log_message(f"\n{'='*70}")
        log_message(f"QUERY: {question}")
        retrieved = query_engine.retrieve(question)
        log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
        # min_score=-0.5 deliberately overrides rerank_nodes' stricter
        # default so near-miss candidates survive the filter.
        reranked = rerank_nodes(question, retrieved, reranker, top_k=20, min_score=-0.5)
        log_message(f"RERANKED: {len(reranked)} nodes")
        # Group reranked nodes per source document, bucketed by node type.
        doc_groups = {}
        for n in reranked:
            doc_id = n.metadata.get('document_id', 'unknown')
            if doc_id not in doc_groups:
                doc_groups[doc_id] = {'tables': [], 'text': [], 'images': []}
            node_type = n.metadata.get('type', 'text')
            if node_type == 'table':
                doc_groups[doc_id]['tables'].append(n)
            elif node_type == 'image':
                doc_groups[doc_id]['images'].append(n)
            else:
                doc_groups[doc_id]['text'].append(n)
        log_message(f"Documents found: {list(doc_groups.keys())}")
        # Build one context section per document.
        # NOTE(review): the 'images' bucket is filled above but never
        # rendered into the context below — confirm this is intentional.
        context_parts = []
        for doc_id, groups in doc_groups.items():
            doc_section = [f"=== ДОКУМЕНТ: {doc_id} ==="]
            # Tables first (most important for the expected queries)
            if groups['tables']:
                doc_section.append("\n--- ТАБЛИЦЫ ---")
                for n in groups['tables']:
                    meta = n.metadata
                    # Prefer the richer identifier; fall back to the bare number.
                    table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
                    title = meta.get('table_title', '')
                    doc_section.append(f"\n[Таблица {table_id}] {title}")
                    doc_section.append(n.text[:1500])  # Cap table text at 1500 chars
                    log_message(f"  Included table {table_id} from {doc_id}")
            # Then plain-text chunks
            if groups['text']:
                doc_section.append("\n--- ТЕКСТ ---")
                for n in groups['text'][:3]:  # At most 3 text chunks per document
                    doc_section.append(n.text[:800])  # Cap each chunk at 800 chars
                    log_message(f"  Included text section from {doc_id}")
            context_parts.append("\n".join(doc_section))
        context = "\n\n" + ("="*70 + "\n\n").join(context_parts)
        log_message(f"Context length: {len(context)} chars")
        # Imported lazily to avoid a module-level dependency cycle — TODO confirm.
        from config import CUSTOM_PROMPT
        prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
        from llama_index.core import Settings
        # Uses the globally configured LLM rather than a per-call instance.
        response = Settings.llm.complete(prompt)
        sources = format_sources(reranked)
        return response.text, sources
    except Exception as e:
        # Best-effort boundary: log the full traceback, return an error string.
        log_message(f"Error: {e}")
        import traceback
        log_message(traceback.format_exc())
        return f"Ошибка: {e}", ""
def rerank_nodes(query, nodes, reranker, top_k=20, min_score=0.1):  # threshold kept deliberately low
    """Score nodes against *query* with a cross-encoder and keep the best.

    Logs the score distribution, drops nodes scoring below ``min_score``,
    and returns at most ``top_k`` nodes. If nothing clears the threshold,
    falls back to the ``top_k`` highest-scoring nodes instead of returning
    an empty list.
    """
    if not nodes or not reranker:
        log_message("WARNING: No nodes or reranker available")
        return nodes[:top_k]
    # Truncate each passage to 500 chars to bound cross-encoder input size.
    candidate_pairs = [[query, node.text[:500]] for node in nodes]
    raw_scores = reranker.predict(candidate_pairs)
    ranked = sorted(zip(nodes, raw_scores), key=lambda item: item[1], reverse=True)
    # Log the spread so threshold tuning is visible in the logs.
    if ranked:
        head_scores = [score for _, score in ranked[:5]]
        tail_scores = [score for _, score in ranked[-5:]]
        log_message(f"Score range: {min(raw_scores):.3f} to {max(raw_scores):.3f}")
        log_message(f"Top 5 scores: {head_scores}")
        log_message(f"Bottom 5 scores: {tail_scores}")
    passing = sum(1 for _, score in ranked if score >= min_score)
    log_message(f"Nodes above threshold ({min_score}): {passing}/{len(ranked)}")
    survivors = [node for node, score in ranked if score >= min_score]
    if survivors:
        chosen = survivors[:top_k]
    else:
        # Nothing met the threshold: return the best scorers anyway.
        chosen = [node for node, _ in ranked[:top_k]]
    log_message(f"Returning {len(chosen)} nodes after reranking")
    return chosen