# NOTE: this file was scraped from a Hugging Face Spaces page; the
# "Spaces: Sleeping" lines were page-status residue, not code.
| from llama_index.llms.google_genai import GoogleGenAI | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from sentence_transformers import CrossEncoder | |
| from my_logging import log_message | |
| def get_llm_model(api_key, model_name="gemini-2.0-flash"): | |
| """Get LLM model""" | |
| return GoogleGenAI(model=model_name, api_key=api_key) | |
def get_embedding_model(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"):
    """Construct the HuggingFace embedding model used for retrieval."""
    embedder = HuggingFaceEmbedding(model_name=model_name)
    return embedder
def get_reranker_model(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2'):
    """Load and return the cross-encoder reranker checkpoint."""
    reranker = CrossEncoder(model_name)
    return reranker
def format_sources(nodes):
    """Format retrieved source nodes into a human-readable, deduplicated list.

    Args:
        nodes: retrieved nodes; each must expose a ``metadata`` dict.
            Recognized keys: ``type`` ('table' / 'image' / anything else
            means text), ``document_id``, and the per-type number/title keys.

    Returns:
        str: one line per unique source, newline-joined, in first-seen order.
    """
    sources = []
    for node in nodes:
        meta = node.metadata
        doc_type = meta.get('type', 'text')
        doc_id = meta.get('document_id', 'unknown')
        if doc_type == 'table':
            table_num = meta.get('table_number', 'unknown')
            title = meta.get('table_title', '')
            sources.append(f"📊 {doc_id} - Таблица {table_num}: {title}")
        elif doc_type == 'image':
            img_num = meta.get('image_number', 'unknown')
            sources.append(f"🖼️ {doc_id} - Рисунок {img_num}")
        else:
            section = meta.get('section_id', '')
            sources.append(f"📄 {doc_id} - Раздел {section}")
    # dict.fromkeys deduplicates while preserving first-seen order; the
    # previous set() made the displayed order nondeterministic between runs.
    return "\n".join(dict.fromkeys(sources))
| import re | |
def answer_question(question, query_engine, reranker):
    """Answer a question: retrieve -> rerank -> group by document -> LLM.

    Args:
        question: user question string.
        query_engine: retriever exposing ``retrieve(question)``.
        reranker: cross-encoder forwarded to ``rerank_nodes``.

    Returns:
        tuple[str, str]: (answer text, formatted source list); on any
        exception, an error message string and an empty sources string.
    """
    try:
        log_message(f"\n{'='*70}")
        log_message(f"QUERY: {question}")
        retrieved = query_engine.retrieve(question)
        log_message(f"RETRIEVED: {len(retrieved)} unique nodes")
        # min_score=-0.5 deliberately overrides rerank_nodes' stricter
        # default so near-miss candidates are still considered.
        reranked = rerank_nodes(question, retrieved, reranker, top_k=20, min_score=-0.5)
        log_message(f"RERANKED: {len(reranked)} nodes")
        # Group by document and type (tables / text / images per doc_id).
        doc_groups = {}
        for n in reranked:
            doc_id = n.metadata.get('document_id', 'unknown')
            if doc_id not in doc_groups:
                doc_groups[doc_id] = {'tables': [], 'text': [], 'images': []}
            node_type = n.metadata.get('type', 'text')
            if node_type == 'table':
                doc_groups[doc_id]['tables'].append(n)
            elif node_type == 'image':
                doc_groups[doc_id]['images'].append(n)
            else:
                doc_groups[doc_id]['text'].append(n)
        log_message(f"Documents found: {list(doc_groups.keys())}")
        # Format context by document, one "=== ДОКУМЕНТ ===" section each.
        context_parts = []
        for doc_id, groups in doc_groups.items():
            doc_section = [f"=== ДОКУМЕНТ: {doc_id} ==="]
            # Tables first (most important for your queries)
            if groups['tables']:
                doc_section.append("\n--- ТАБЛИЦЫ ---")
                for n in groups['tables']:
                    meta = n.metadata
                    table_id = meta.get('table_identifier', meta.get('table_number', 'unknown'))
                    title = meta.get('table_title', '')
                    doc_section.append(f"\n[Таблица {table_id}] {title}")
                    doc_section.append(n.text[:1500])  # Limit length
                    log_message(f" Included table {table_id} from {doc_id}")
            # Then text.
            # NOTE(review): grouped 'images' are collected above but never
            # added to the context — confirm whether that is intentional.
            if groups['text']:
                doc_section.append("\n--- ТЕКСТ ---")
                for n in groups['text'][:3]:  # Limit text chunks
                    doc_section.append(n.text[:800])
                    log_message(f" Included text section from {doc_id}")
            context_parts.append("\n".join(doc_section))
        context = "\n\n" + ("="*70 + "\n\n").join(context_parts)
        log_message(f"Context length: {len(context)} chars")
        # Local imports: prompt template and configured LLM are resolved at
        # call time, keeping module import cheap and avoiding cycles.
        from config import CUSTOM_PROMPT
        prompt = CUSTOM_PROMPT.format(context_str=context, query_str=question)
        from llama_index.core import Settings
        response = Settings.llm.complete(prompt)
        sources = format_sources(reranked)
        return response.text, sources
    except Exception as e:
        # Boundary handler: log the full traceback, return a user-facing
        # error string instead of crashing the UI.
        log_message(f"Error: {e}")
        import traceback
        log_message(traceback.format_exc())
        return f"Ошибка: {e}", ""
def rerank_nodes(query, nodes, reranker, top_k=20, min_score=0.1):
    """Rerank candidate nodes with a cross-encoder, logging score stats.

    Args:
        query: user query string.
        nodes: candidate nodes exposing a ``text`` attribute.
        reranker: object with ``predict(pairs)``; may be falsy (skips rerank).
        top_k: maximum number of nodes to return.
        min_score: score threshold; if nothing passes, the best top_k are
            returned anyway as a fallback.

    Returns:
        list: up to ``top_k`` nodes ordered by descending reranker score.
    """
    if not nodes or not reranker:
        log_message("WARNING: No nodes or reranker available")
        return nodes[:top_k]
    # Truncate node text so cross-encoder inputs stay short.
    pairs = [[query, node.text[:500]] for node in nodes]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)
    # Detailed score logging for debugging retrieval quality.
    if ranked:
        log_message(f"Score range: {min(scores):.3f} to {max(scores):.3f}")
        log_message(f"Top 5 scores: {[score for _, score in ranked[:5]]}")
        log_message(f"Bottom 5 scores: {[score for _, score in ranked[-5:]]}")
    passing = sum(1 for _, score in ranked if score >= min_score)
    log_message(f"Nodes above threshold ({min_score}): {passing}/{len(ranked)}")
    kept = [node for node, score in ranked if score >= min_score]
    if kept:
        result = kept[:top_k]
    else:
        # Fallback: never return empty — take the best-scoring top_k.
        result = [node for node, _ in ranked[:top_k]]
    log_message(f"Returning {len(result)} nodes after reranking")
    return result