soft.engineer committed
Commit · 5ee86b8
Parent(s): 19167e3

initial project
- .env.example +16 -0
- .gitignore +63 -0
- E2E_TEST_DESIGN.md +544 -0
- EVALUATION_GUIDE.md +323 -0
- HUGGINGFACE_DEPLOYMENT.md +325 -0
- MODEL_DB_CONFIG.md +451 -0
- README.md +605 -4
- SETUP_GUIDE.md +512 -0
- app.py +1552 -0
- core/comparison.py +148 -0
- core/eval.py +420 -0
- core/index.py +389 -0
- core/ingest.py +667 -0
- core/report_generator.py +386 -0
- core/reranker.py +183 -0
- core/retrieval.py +507 -0
- core/session_manager.py +353 -0
- core/session_rag.py +100 -0
- core/tag_generator.py +464 -0
- core/utils.py +148 -0
- core/visualization.py +291 -0
- requirements.txt +55 -0
- tests/README.md +283 -0
- tests/__init__.py +2 -0
- tests/conftest.py +154 -0
- tests/test_accuracy.py +231 -0
- tests/test_japanese_support.py +87 -0
- tests/test_mcp_server.py +263 -0
- tests/test_robustness.py +194 -0
- tests/test_user_scenarios.py +174 -0
- tests/test_ux.py +156 -0
.env.example
ADDED
@@ -0,0 +1,16 @@
OPENAI_API_KEY=your_openai_api_key_here

# =============================================================================
# OPTIONAL CONFIGURATION VARIABLES
# (Can be set but have defaults - see README for full documentation)
# =============================================================================

# These variables are optional and have sensible defaults.
# Uncomment and modify only if you need to override defaults:

# OPENAI_MODEL=gpt-4o-mini
# OPENAI_EMBED_MODEL=text-embedding-3-small
# ST_EMBED_MODEL=all-MiniLM-L6-v2
# CHROMA_PERSIST_DIR=./chroma_data
# SESSION_TIMEOUT=3600
# LOG_LEVEL=INFO
.gitignore
ADDED
@@ -0,0 +1,63 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
ENV/
env/
.venv

# Environment variables (SECURITY)
.env
.env.local
.env.*.local

# IDE
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Project specific
chroma_data/
reports/
*.log
*.sqlite3

# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/

# Jupyter Notebook
.ipynb_checkpoints

# Model files (large)
*.bin
*.safetensors

# Temporary files
*.tmp
*.temp
E2E_TEST_DESIGN.md
ADDED
@@ -0,0 +1,544 @@
# E2E Test Design Document

## Overview

This document outlines comprehensive End-to-End (E2E) test cases for the Auto Tagging RAG System. The test suite covers accuracy, user experience, and robustness scenarios, including testable cases for non-technical users.

## Test Categories

### 1. Accuracy Tests

#### 1.1 Tag Generation Accuracy

**Test Case: TAG-001 - English Document Tag Generation**
- **Objective**: Verify accurate tag generation for English documents
- **Steps**:
  1. Upload sample English document (`sample_documents/emergency_procedures.txt`)
  2. Select language: "en" or "Auto"
  3. Click "Build RAG Index"
  4. Verify tags generated match document content
- **Expected Result**: Tags include relevant keywords like "emergency", "fire", "safety", "procedure"
- **Success Criteria**: At least 5 relevant tags generated, tags match document topics

**Test Case: TAG-002 - Japanese Document Tag Generation**
- **Objective**: Verify accurate tag generation for Japanese documents
- **Steps**:
  1. Upload Japanese document (test document)
  2. Select language: "ja" or "Auto"
  3. Click "Build RAG Index"
  4. Verify Japanese tags are generated correctly
- **Expected Result**: Tags in Japanese characters, relevant to document content
- **Success Criteria**: Tags are in Japanese, relevant to content

**Test Case: TAG-003 - Manual Tag Integration**
- **Objective**: Verify manual tags are added and prioritized
- **Steps**:
  1. Upload document
  2. Add manual tags: "custom-tag-1, custom-tag-2"
  3. Build RAG Index
  4. Check tag visualization
- **Expected Result**: Manual tags appear first in tag list, combined with auto-generated tags
- **Success Criteria**: Manual tags are prepended, no duplicates

#### 1.2 Retrieval Accuracy

**Test Case: RET-001 - Base RAG Retrieval**
- **Objective**: Verify Base RAG returns relevant documents
- **Steps**:
  1. Upload multiple sample documents
  2. Build RAG Index
  3. Search for "emergency procedures"
  4. Verify returned documents are relevant
- **Expected Result**: Top result contains emergency procedure content
- **Success Criteria**: Precision@1 > 0.5

**Test Case: RET-002 - Tag Filter RAG with OR Operator**
- **Objective**: Verify tag filtering works with OR operator
- **Steps**:
  1. Upload documents with known tags
  2. Search with tags: "fire, emergency" (OR operator)
  3. Verify results match any tag
- **Expected Result**: Results contain documents with "fire" OR "emergency" tags
- **Success Criteria**: All results have at least one matching tag

**Test Case: RET-003 - Tag Filter RAG with AND Operator**
- **Objective**: Verify tag filtering works with AND operator
- **Steps**:
  1. Search with tags: "fire, emergency" (AND operator)
  2. Verify results match all tags
- **Expected Result**: Results contain documents with BOTH "fire" AND "emergency" tags
- **Success Criteria**: All results have all specified tags

**Test Case: RET-004 - Hybrid RAG Weight Tuning**
- **Objective**: Verify hybrid search combines vector and tag scores correctly (see the scoring sketch after this subsection)
- **Steps**:
  1. Search with vector_weight=0.7, tag_weight=0.3
  2. Compare results with vector_weight=0.3, tag_weight=0.7
  3. Verify different results based on weights
- **Expected Result**: Results change based on weight configuration
- **Success Criteria**: Different weight configurations produce different rankings

**Test Case: RET-005 - Hybrid Rerank RAG**
- **Objective**: Verify reranking improves result relevance
- **Steps**:
  1. Search with Hybrid RAG (baseline)
  2. Search with Hybrid Rerank RAG
  3. Compare top results
- **Expected Result**: Reranked results show higher semantic similarity scores
- **Success Criteria**: nDCG@3 improves with reranking

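For RET-004, a minimal sketch of the weighted score combination this test exercises. The function names and the tag-overlap normalization here are assumptions for illustration; the actual scoring lives in `core/retrieval.py` and may differ.

```python
# Hypothetical sketch of hybrid scoring (assumed form, not the project's exact code).
def tag_overlap(query_tags, doc_tags):
    """Fraction of query tags present on the document (0.0-1.0)."""
    return len(query_tags & doc_tags) / len(query_tags) if query_tags else 0.0

def hybrid_score(vector_score, tag_score, vector_weight=0.7, tag_weight=0.3):
    """Weighted combination of a vector-similarity score and a tag-overlap score."""
    return vector_weight * vector_score + tag_weight * tag_score

# Flipping the weights should visibly re-rank results, as RET-004 expects:
q_tags = {"fire", "emergency"}
doc = {"vector_score": 0.9, "tags": {"fire"}}
overlap = tag_overlap(q_tags, doc["tags"])                   # 0.5
print(hybrid_score(doc["vector_score"], overlap, 0.7, 0.3))  # 0.78
print(hybrid_score(doc["vector_score"], overlap, 0.3, 0.7))  # 0.62
```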
#### 1.3 Evaluation Metrics Accuracy

A pytest-style sanity check for these cases is sketched at the end of this subsection.

**Test Case: MET-001 - Precision@k Calculation**
- **Objective**: Verify Precision@k is calculated correctly
- **Steps**:
  1. Use sample evaluation queries (`sample_evaluation_queries.json`)
  2. Run evaluation with ground truth
  3. Verify Precision@1, Precision@3, Precision@5 values
- **Expected Result**: Precision@k values match expected ranges (0.0-1.0)
- **Success Criteria**: Precision@3 > 0.3 for sample data

**Test Case: MET-002 - nDCG@k Calculation**
- **Objective**: Verify nDCG@k considers ranking order
- **Steps**:
  1. Run evaluation
  2. Compare nDCG@3 and nDCG@5
- **Expected Result**: nDCG@5 >= nDCG@3 (more results improve ranking)
- **Success Criteria**: nDCG values increase with k

**Test Case: MET-003 - MRR Calculation**
- **Objective**: Verify MRR reflects first relevant result position
- **Steps**:
  1. Run evaluation with queries having clear first match
  2. Verify MRR value
- **Expected Result**: MRR reflects position of first relevant result
- **Success Criteria**: MRR > 0.3 for sample queries

**Test Case: MET-004 - User Satisfaction Integration**
- **Objective**: Verify user satisfaction scores are recorded
- **Steps**:
  1. Provide user satisfaction JSON (`sample_user_satisfaction.json`)
  2. Run evaluation
  3. Check results include satisfaction scores
- **Expected Result**: CSV/JSON reports include user_satisfaction column
- **Success Criteria**: Satisfaction scores appear in all output formats

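A minimal pytest-style sketch of the range and threshold checks above. The import and the `run_evaluation` signature are assumptions standing in for the project's real entry point in `core/eval.py`, which may be named and shaped differently.

```python
import json

# Hypothetical import: the real evaluation entry point lives in core/eval.py
# and may expose a different name and signature.
from core.eval import run_evaluation  # assumed

def test_metric_ranges_and_thresholds():
    with open("sample_evaluation_queries.json", encoding="utf-8") as f:
        queries = json.load(f)
    results = run_evaluation(queries, k_values=[1, 3, 5])  # assumed signature
    for row in results:
        # MET-001: Precision@k must be a valid ratio.
        assert 0.0 <= row["precision_at_k"] <= 1.0
        # MET-003: MRR is likewise bounded to [0, 1].
        assert 0.0 <= row["mrr"] <= 1.0
    # MET-001 success criterion for the sample data.
    p_at_3 = [r["precision_at_k"] for r in results if r["k"] == 3]
    assert sum(p_at_3) / len(p_at_3) > 0.3
```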
### 2. User Experience Tests

#### 2.1 Non-Technical User Scenarios

**Test Case: UX-001 - First-Time User Document Upload**
- **User Type**: Non-technical user
- **Objective**: First-time user can upload and process documents
- **Steps**:
  1. Open application
  2. Navigate to "Upload & Tagging" tab
  3. Drag and drop PDF/TXT files
  4. Leave language as "Auto"
  5. Click "Build RAG Index"
  6. Wait for processing
- **Expected Result**:
  - Files are uploaded successfully
  - Processing completes without errors
  - Tags are generated and displayed
  - Document count updates
- **Success Criteria**: User completes upload without reading documentation

**Test Case: UX-002 - Simple Search Query**
- **User Type**: Non-technical user
- **Objective**: User can search documents without understanding technical details
- **Steps**:
  1. Upload documents
  2. Go to "Search & Compare" tab
  3. Enter query: "What are emergency procedures?"
  4. Click "Search All Methods"
  5. Review results
- **Expected Result**: Results appear with readable text, no technical jargon visible
- **Success Criteria**: User finds relevant information without confusion

**Test Case: UX-003 - Chat Interface Usage**
- **User Type**: Non-technical user
- **Objective**: User can chat naturally with the system
- **Steps**:
  1. Go to "Chat Interface" tab
  2. Type: "Tell me about fire safety"
  3. Click "Send"
  4. Review answer and sources
- **Expected Result**: Natural language answer, sources visible in accordion
- **Success Criteria**: Answer is clear and helpful, sources are accessible

**Test Case: UX-004 - Evaluation for Non-Technical User**
- **User Type**: Non-technical user
- **Objective**: User can run basic evaluation with sample data
- **Steps**:
  1. Copy sample queries from `sample_evaluation_queries.json`
  2. Paste into "Evaluation Queries" field
  3. Click "Run Evaluation"
  4. View results and charts
- **Expected Result**: Charts display, results are understandable
- **Success Criteria**: User can interpret charts without technical knowledge

**Test Case: UX-005 - Session Persistence**
- **User Type**: Non-technical user
- **Objective**: User's data persists across browser refresh
- **Steps**:
  1. Upload documents and build index
  2. Note document count
  3. Refresh browser page
  4. Check session and document count
- **Expected Result**: Same session ID, same document count, data accessible
- **Success Criteria**: No data loss after refresh

#### 2.2 Advanced User Scenarios

**Test Case: UX-006 - Tag Weight Tuning**
- **User Type**: Technical user
- **Objective**: Advanced user can tune hybrid search weights
- **Steps**:
  1. Go to "Search & Compare" tab
  2. Adjust vector weight slider (0.0-1.0)
  3. Adjust tag weight slider (0.0-1.0)
  4. Search and compare results
- **Expected Result**: Results change based on weight configuration
- **Success Criteria**: Weights affect result ranking visibly

**Test Case: UX-007 - Custom Tag Input**
- **User Type**: Technical user
- **Objective**: User can add custom tags during upload
- **Steps**:
  1. Upload document
  2. Enter custom tags: "project-alpha, confidential"
  3. Build index
  4. Verify tags in visualization
- **Expected Result**: Custom tags appear in tag list, used in filtering
- **Success Criteria**: Custom tags work in tag-based search

**Test Case: UX-008 - Export Functionality**
- **User Type**: Technical user
- **Objective**: User can export evaluation results
- **Steps**:
  1. Run evaluation
  2. Click "Download CSV"
  3. Click "Download Charts"
- **Expected Result**: Files download automatically without extra clicks
- **Success Criteria**: One-click download works, files are valid

**Test Case: UX-009 - Multi-Document Processing**
- **User Type**: Advanced user
- **Objective**: User can process multiple documents simultaneously
- **Steps**:
  1. Upload 5+ documents
  2. Build index
  3. Verify all documents indexed
- **Expected Result**: All documents processed, unique document count correct
- **Success Criteria**: Document count matches number of uploaded files

#### 2.3 UI/UX Quality Tests

**Test Case: UX-010 - Responsive Design**
- **Objective**: UI works on different screen sizes
- **Steps**:
  1. Test on desktop (1920x1080)
  2. Test on tablet (768x1024)
  3. Test on mobile (375x667)
- **Expected Result**: All tabs accessible, forms usable, no horizontal scroll
- **Success Criteria**: UI adapts to screen size

**Test Case: UX-011 - Loading States**
- **Objective**: Users see feedback during processing
- **Steps**:
  1. Upload large document (50+ pages)
  2. Observe UI during processing
- **Expected Result**: Progress indicators visible, status messages clear
- **Success Criteria**: User understands system is working

**Test Case: UX-012 - Error Messages**
- **Objective**: Errors are user-friendly
- **Steps**:
  1. Upload invalid file (corrupted PDF)
  2. Search with empty query
  3. Run evaluation with invalid JSON
- **Expected Result**: Clear error messages, actionable guidance
- **Success Criteria**: Errors help user fix issues

### 3. Robustness Tests

#### 3.1 Error Handling

**Test Case: ROB-001 - Invalid File Upload**
- **Objective**: System handles invalid files gracefully
- **Steps**:
  1. Upload corrupted PDF file
  2. Upload non-PDF/TXT file (e.g., .exe)
  3. Upload empty file
- **Expected Result**:
  - Error message displayed
  - System remains functional
  - No crashes
- **Success Criteria**: Graceful error handling, no crashes

**Test Case: ROB-002 - Invalid JSON in Evaluation**
- **Objective**: System handles malformed JSON
- **Steps**:
  1. Enter invalid JSON in evaluation queries
  2. Click "Run Evaluation"
- **Expected Result**: Clear error message about JSON format
- **Success Criteria**: Error is helpful, system recovers

**Test Case: ROB-003 - Empty Query Search**
- **Objective**: System handles empty search queries
- **Steps**:
  1. Leave search query empty
  2. Click "Search All Methods"
- **Expected Result**: Error message or no results message
- **Success Criteria**: No crashes, clear feedback

**Test Case: ROB-004 - Missing Ground Truth in Evaluation**
- **Objective**: System handles missing ground truth
- **Steps**:
  1. Create evaluation query without ground_truth field
  2. Run evaluation
- **Expected Result**: Evaluation runs, metrics are 0 or skipped
- **Success Criteria**: System continues, no errors

**Test Case: ROB-005 - Large Document Processing**
- **Objective**: System handles large documents
- **Steps**:
  1. Upload 100+ page PDF
  2. Build index
  3. Monitor memory usage
- **Expected Result**: Processing completes, no memory errors
- **Success Criteria**: Large documents process successfully

#### 3.2 Edge Cases

**Test Case: ROB-006 - Very Short Documents**
- **Objective**: System handles minimal content
- **Steps**:
  1. Upload document with 1 sentence
  2. Build index
  3. Search for content
- **Expected Result**: Tags generated, search works
- **Success Criteria**: Minimal content is processed

**Test Case: ROB-007 - Special Characters in Documents**
- **Objective**: System handles special characters
- **Steps**:
  1. Upload document with special characters (é, 日本語, 🎉)
  2. Build index
  3. Search with special characters
- **Expected Result**: Special characters preserved, search works
- **Success Criteria**: Unicode handled correctly

**Test Case: ROB-008 - Concurrent Sessions**
- **Objective**: Multiple users can use system simultaneously
- **Steps**:
  1. Open application in two browser windows
  2. Upload different documents in each
  3. Verify isolation
- **Expected Result**: Each session has separate data, no interference
- **Success Criteria**: Session isolation works correctly

**Test Case: ROB-009 - Session Expiration Handling**
- **Objective**: System handles expired sessions
- **Steps**:
  1. Create session
  2. Wait for expiration (if configured)
  3. Try to access session
- **Expected Result**: New session created or session restored
- **Success Criteria**: Graceful session handling

**Test Case: ROB-010 - Network Interruption**
- **Objective**: System handles offline mode
- **Steps**:
  1. Disconnect network
  2. Upload documents
  3. Search documents
- **Expected Result**: Works offline (after initial model downloads)
- **Success Criteria**: Offline functionality works

#### 3.3 Data Integrity

**Test Case: ROB-011 - Document Count Accuracy**
- **Objective**: Document count reflects unique documents
- **Steps**:
  1. Upload 3 documents
  2. Check document count
  3. Upload 2 more documents
  4. Verify count updates to 5
- **Expected Result**: Count matches unique documents, not chunks
- **Success Criteria**: Accurate document counting

**Test Case: ROB-012 - Tag Consistency**
- **Objective**: Same document produces consistent tags
- **Steps**:
  1. Upload document
  2. Note tags
  3. Reset index
  4. Upload same document again
  5. Compare tags
- **Expected Result**: Tags are consistent (may vary slightly due to randomness)
- **Success Criteria**: Similar tags generated for same document

**Test Case: ROB-013 - Index Reset**
- **Objective**: Reset clears all data correctly
- **Steps**:
  1. Upload documents
  2. Build index
  3. Click "Reset Index"
  4. Verify document count is 0
  5. Try to search
- **Expected Result**: Count resets, search returns no results
- **Success Criteria**: Complete reset functionality

### 4. Performance Tests

#### 4.1 Response Time

**Test Case: PERF-001 - Document Upload Speed**
- **Objective**: Upload completes in reasonable time
- **Steps**:
  1. Upload 10 documents (total ~1MB)
  2. Measure time to completion
- **Expected Result**: Processing completes within 30 seconds
- **Success Criteria**: < 30s for 10 documents

**Test Case: PERF-002 - Search Latency**
- **Objective**: Search returns results quickly
- **Steps**:
  1. Index 100 documents
  2. Measure search latency for each method
- **Expected Result**:
  - Base RAG: < 1s
  - Tag Filter: < 1s
  - Hybrid: < 2s
  - Hybrid Rerank: < 5s
- **Success Criteria**: All methods meet latency targets

**Test Case: PERF-003 - Evaluation Speed**
- **Objective**: Evaluation completes in reasonable time
- **Steps**:
  1. Run evaluation with 10 queries, 3 k-values
  2. Measure total time
- **Expected Result**: Completes within 60 seconds
- **Success Criteria**: < 60s for 10 queries × 4 pipelines × 3 k-values

#### 4.2 Scalability

**Test Case: PERF-004 - Large Document Set**
- **Objective**: System handles 1000+ documents
- **Steps**:
  1. Index 1000 documents
  2. Perform searches
  3. Monitor memory and CPU
- **Expected Result**: System remains responsive
- **Success Criteria**: Memory usage reasonable, search latency acceptable

**Test Case: PERF-005 - Concurrent Users**
- **Objective**: Multiple users don't degrade performance
- **Steps**:
  1. Simulate 5 concurrent users
  2. Each performs searches
  3. Monitor performance
- **Expected Result**: No significant performance degradation
- **Success Criteria**: Latency stays within acceptable range

### 5. Integration Tests

#### 5.1 API Integration

**Test Case: INT-001 - Gradio API Access**
- **Objective**: Verify API endpoints work (see the client sketch after this test case)
- **Steps**:
  1. Use Gradio Client to call API
  2. Test build_rag API
  3. Test search API
  4. Test evaluate API
- **Expected Result**: All APIs return expected results
- **Success Criteria**: API endpoints functional

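A minimal sketch of INT-001 using the official `gradio_client` package. The app URL and the `api_name` value here are assumptions; check the running app's "Use via API" page for the real endpoint names before relying on them.

```python
from gradio_client import Client

# Hypothetical target: a locally running app or the deployed Space id.
client = Client("http://localhost:7860")  # or "YOUR_USERNAME/auto-tagging-rag"

# Assumed endpoint name and argument order; verify against "Use via API".
result = client.predict(
    "What are emergency procedures?",  # query string
    api_name="/search",                # assumed endpoint
)
print(result)
```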
**Test Case: INT-002 - MCP Server Integration**
- **Objective**: Verify MCP server functions
- **Steps**:
  1. Connect MCP client
  2. Call MCP tools
  3. Verify responses
- **Expected Result**: MCP tools work correctly
- **Success Criteria**: MCP integration functional

#### 5.2 Data Flow

**Test Case: INT-003 - End-to-End Workflow**
- **Objective**: Complete workflow from upload to evaluation
- **Steps**:
  1. Upload documents
  2. Build index
  3. Search documents
  4. Run evaluation
  5. Download reports
- **Expected Result**: All steps complete successfully
- **Success Criteria**: Complete workflow functional

## Test Execution Guidelines

### Test Environment Setup

1. **Prerequisites**:
   - Python 3.8+
   - All dependencies installed (`pip install -r requirements.txt`)
   - spaCy model downloaded (`python -m spacy download en_core_web_sm`)
   - Sample documents available in `sample_documents/`

2. **Test Data**:
   - Use provided sample documents
   - Use `sample_evaluation_queries.json` for evaluation tests
   - Use `sample_user_satisfaction.json` for satisfaction tests

3. **Test Execution**:
   - Manual tests: Follow step-by-step instructions
   - Automated tests: Run pytest (if implemented)
   - Document results and any issues found

### Test Reporting

For each test case:
1. **Status**: Pass / Fail / Blocked / Not Tested
2. **Notes**: Observations, issues, screenshots
3. **Environment**: Browser, OS, Python version
4. **Date**: Test execution date

### Non-Technical User Testing

For UX tests (UX-001 to UX-005):
- Use actual non-technical users when possible
- Provide minimal instruction
- Observe user behavior
- Note confusion points
- Measure task completion time

## Success Criteria Summary

### Must Pass (Critical)
- All Accuracy Tests (TAG-001 to MET-004)
- All Robustness Tests (ROB-001 to ROB-013)
- Core UX Tests (UX-001, UX-002, UX-005)

### Should Pass (Important)
- Advanced UX Tests (UX-006 to UX-012)
- Performance Tests (PERF-001 to PERF-003)
- Integration Tests (INT-001 to INT-003)

### Nice to Have
- Scalability Tests (PERF-004, PERF-005)
- Edge case handling
- Performance optimizations

## Test Maintenance

- Update test cases when features change
- Add new test cases for new features
- Review and refine test cases quarterly
- Keep test data updated
EVALUATION_GUIDE.md
ADDED
@@ -0,0 +1,323 @@
# Evaluation Guide - How to Run Evaluation with Samples

## Overview

The **Analytics & Evaluation** tab allows you to run comprehensive quantitative evaluation of all 4 retrieval methods using test queries with ground truth documents.

## Input Format

### 1. Evaluation Queries (JSON)

**Required Format:**
```json
[
  {
    "query": "Your question here",
    "ground_truth": ["chunk_content_1", "chunk_content_2"],
    "k_values": [1, 3, 5],
    "tags": ["tag1", "tag2"],
    "tag_operator": "OR",
    "vector_weight": 0.7,
    "tag_weight": 0.3
  }
]
```

**Fields:**
- **`query`** (required): The search question/query string
- **`ground_truth`** (required): List of actual document contents that should be retrieved. These should match the **actual text content** of chunks in your indexed documents.
- **`k_values`** (optional): List of k values to test (default: `[1, 3, 5]`)
- **`tags`** (optional): Tags for tag-based pipelines
- **`tag_operator`** (optional): `"OR"`, `"AND"`, or `"NOT"` (default: `"OR"`)
- **`vector_weight`** (optional): For hybrid pipelines (default: `0.7`)
- **`tag_weight`** (optional): For hybrid pipelines (default: `0.3`)

A small validation sketch for this format follows below.

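This is a convenience illustration, not part of the project's own tooling: a minimal check that an evaluation-queries file has the required shape before you paste it into the UI.

```python
import json

REQUIRED = ("query", "ground_truth")

def validate_queries(path):
    """Check that an evaluation-queries JSON file matches the required format."""
    with open(path, encoding="utf-8") as f:
        queries = json.load(f)
    assert isinstance(queries, list), "Top level must be a JSON array"
    for i, q in enumerate(queries):
        for field in REQUIRED:
            assert field in q, f"query {i}: missing required field '{field}'"
        assert isinstance(q["ground_truth"], list), f"query {i}: ground_truth must be a list"
    print(f"OK: {len(queries)} queries validated")

validate_queries("sample_evaluation_queries.json")
```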
### 2. User Satisfaction Scores (JSON, Optional)

**Format:**
```json
{
  "query_0": 4.5,
  "query_1": 3.8,
  "query_2": 5.0
}
```

- Keys are `"query_0"`, `"query_1"`, etc. (index-based)
- Values are satisfaction scores (typically 1-5)

## Sample Evaluation Input

### Example 1: Basic Evaluation

```json
[
  {
    "query": "What are the emergency procedures for fire incidents?",
    "ground_truth": [
      "In case of fire, immediately activate the nearest fire alarm and evacuate the building following the posted exit routes.",
      "Fire safety protocols require all personnel to know the location of fire extinguishers and emergency exits.",
      "During fire emergencies, do not use elevators and stay low to avoid smoke inhalation."
    ],
    "k_values": [1, 3, 5]
  },
  {
    "query": "What equipment is needed for patient safety monitoring?",
    "ground_truth": [
      "Standard patient monitoring equipment includes blood pressure cuffs, pulse oximeters, and ECG monitors.",
      "Safety monitoring requires regular calibration of medical devices and documented maintenance logs."
    ],
    "k_values": [1, 3, 5]
  }
]
```

### Example 2: With Tags

```json
[
  {
    "query": "What are surgical safety protocols?",
    "ground_truth": [
      "All surgical procedures require pre-operative checklists and sterile environment protocols.",
      "Surgical safety includes patient identification verification and site marking procedures.",
      "Post-operative care involves monitoring vital signs and wound care instructions."
    ],
    "k_values": [1, 3, 5],
    "tags": ["surgery", "safety", "protocol"],
    "tag_operator": "AND"
  },
  {
    "query": "How to handle medical emergencies?",
    "ground_truth": [
      "Medical emergency response begins with assessing patient ABC (Airway, Breathing, Circulation).",
      "Emergency protocols require immediate notification of medical team and preparation of emergency equipment."
    ],
    "k_values": [1, 3, 5],
    "tags": ["emergency", "medical", "response"],
    "tag_operator": "OR"
  }
]
```

### Example 3: With User Satisfaction

**Evaluation Queries:**
```json
[
  {
    "query": "What are infection control measures?",
    "ground_truth": [
      "Infection control requires hand hygiene, use of personal protective equipment, and proper sterilization of instruments.",
      "Standard precautions must be followed for all patients to prevent transmission of infectious diseases."
    ],
    "k_values": [1, 3, 5]
  },
  {
    "query": "What are patient care guidelines?",
    "ground_truth": [
      "Patient care guidelines emphasize respect for patient autonomy, informed consent, and maintaining confidentiality.",
      "Care protocols require documentation of all interventions and regular assessment of patient condition."
    ],
    "k_values": [1, 3, 5]
  }
]
```

**User Satisfaction Scores:**
```json
{
  "query_0": 4.5,
  "query_1": 4.2
}
```

## Step-by-Step Instructions

### Step 1: Upload Documents
1. Go to **Upload & Tagging** tab
2. Upload your PDF/TXT documents
3. Click **"Build RAG Index"**
4. Wait for indexing to complete

### Step 2: Prepare Ground Truth
**Important:** Ground truth must match the **actual text content** of chunks in your indexed documents.

**How to find ground truth:**
1. Use **Search & Compare** tab to search for similar queries
2. Check the retrieved document content
3. Copy the exact text from relevant chunks
4. Use these as your `ground_truth` array

**Example:**
If a chunk contains:
```
"Fire safety protocols require all personnel to know the location of fire extinguishers and emergency exits."
```

Then use:
```json
"ground_truth": ["Fire safety protocols require all personnel to know the location of fire extinguishers and emergency exits."]
```

### Step 3: Enter Evaluation Queries
1. Go to **Analytics & Evaluation** tab
2. In **"Evaluation Queries (JSON)"** field, paste your JSON array
3. Use the sample format above as a template

### Step 4: (Optional) Add User Satisfaction
1. In **"User Satisfaction Scores (JSON, optional)"** field
2. Enter satisfaction scores as JSON object
3. Use `query_0`, `query_1`, etc. as keys

### Step 5: Set Output Filename
1. In **"Output Filename"** field
2. Enter filename (e.g., `evaluation_results.csv`)
3. Results will be saved to `reports/` directory

### Step 6: Run Evaluation
1. Click **"Run Evaluation"** button
2. Wait for evaluation to complete (may take several minutes)
3. Results will appear in:
   - **Evaluation Status**: Summary message
   - **Evaluation Results**: DataFrame with all metrics
   - **Summary Statistics**: Aggregated metrics by pipeline
   - **Visualization Tabs**: Charts and graphs

## Understanding Results

### Metrics Explained

Each metric is summarized below; a short computation sketch follows the list.

- **Precision@k**: Fraction of retrieved documents that are relevant
  - Range: 0.0 - 1.0 (higher is better)
  - Example: 0.8 means 80% of retrieved docs are relevant

- **nDCG@k**: Normalized Discounted Cumulative Gain
  - Range: 0.0 - 1.0 (higher is better)
  - Measures ranking quality with position weighting

- **Hit@k**: Whether at least one relevant document is in top-k
  - Value: 0.0 or 1.0 (1.0 = found at least one relevant doc)

- **MRR**: Mean Reciprocal Rank
  - Range: 0.0 - 1.0 (higher is better)
  - Average of 1/rank where first relevant doc appears

- **Semantic Similarity**: Average cosine similarity between query and retrieved docs
  - Range: 0.0 - 1.0 (higher is better)

- **Latency**: Response time in seconds (lower is better)

- **User Satisfaction**: Average satisfaction score (if provided)
  - Range: depends on your scale (typically 1-5)

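For readers who want the math pinned down, a minimal sketch of how Precision@k, MRR, and binary-relevance nDCG@k are conventionally computed. The project's own implementation in `core/eval.py` may differ in details (e.g., how relevance matches are decided).

```python
import math

def precision_at_k(retrieved, relevant, k):
    """Fraction of the top-k retrieved items that are relevant."""
    return sum(1 for doc in retrieved[:k] if doc in relevant) / k

def mrr(retrieved, relevant):
    """Reciprocal rank of the first relevant item (0.0 if none found)."""
    for rank, doc in enumerate(retrieved, start=1):
        if doc in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(retrieved, relevant, k):
    """Binary-relevance nDCG: DCG divided by the ideal ranking's DCG."""
    dcg = sum(1.0 / math.log2(i + 2)
              for i, doc in enumerate(retrieved[:k]) if doc in relevant)
    ideal = sum(1.0 / math.log2(i + 2) for i in range(min(k, len(relevant))))
    return dcg / ideal if ideal > 0 else 0.0

retrieved = ["doc_b", "doc_a", "doc_d"]
relevant = {"doc_a", "doc_c"}
print(precision_at_k(retrieved, relevant, 3))  # 0.333...
print(mrr(retrieved, relevant))                # 0.5 (first hit at rank 2)
print(ndcg_at_k(retrieved, relevant, 3))       # ~0.39
```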
### Results DataFrame

Columns include:
- `query_id`: Query identifier
- `query`: Query text
- `pipeline`: Pipeline name (base_rag, tag_filter_rag, hybrid_rag, hybrid_rerank_rag)
- `k`: Number of results requested
- `precision_at_k`: Precision metric
- `ndcg_at_k`: nDCG metric
- `hit_at_k`: Hit metric
- `mrr`: MRR metric
- `semantic_similarity`: Similarity score
- `latency`: Response time
- `retrieved_count`: Number of documents retrieved
- `user_satisfaction`: Satisfaction score (if provided)

## Common Issues and Solutions

### Issue 1: "No results found" or Low Precision

**Problem:** Ground truth doesn't match indexed documents

**Solution:**
1. Check that ground truth text **exactly matches** chunk content
2. Use **Search & Compare** to verify what's actually indexed
3. Copy exact text from retrieved chunks

### Issue 2: "Invalid JSON format"

**Problem:** JSON syntax error

**Solution:**
1. Validate JSON using an online JSON validator
2. Ensure all strings are in double quotes `"`, not single quotes `'`
3. Ensure no trailing commas
4. Check brackets and braces are balanced

### Issue 3: Evaluation Takes Too Long

**Problem:** Too many queries or high k values

**Solution:**
1. Start with 2-3 queries
2. Use lower k values (e.g., `[1, 3]` instead of `[1, 3, 5, 10]`)
3. Evaluation runs sequentially - be patient

### Issue 4: All Metrics Are Zero

**Problem:** Ground truth doesn't match any retrieved documents

**Solution:**
1. Verify documents are actually indexed (check document count)
2. Check that ground truth text matches indexed chunk content exactly
3. Use semantic matching threshold (system uses ~0.8 similarity threshold; see the matching sketch below)

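A minimal sketch of how a retrieved chunk might be matched against ground truth with an embedding-similarity threshold. The ~0.8 cutoff comes from this guide, and `all-MiniLM-L6-v2` is the default named in `.env.example`; the matching logic itself is an assumption, not the project's exact code.

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # project default per .env.example
THRESHOLD = 0.8  # approximate threshold mentioned in this guide

def is_relevant(retrieved_chunk, ground_truth):
    """True if the chunk is semantically close to any ground-truth text."""
    emb_chunk = model.encode(retrieved_chunk, convert_to_tensor=True)
    emb_truth = model.encode(ground_truth, convert_to_tensor=True)
    scores = util.cos_sim(emb_chunk, emb_truth)  # 1 x len(ground_truth)
    return bool(scores.max() >= THRESHOLD)

print(is_relevant(
    "Fire safety protocols require all personnel to know the exits.",
    ["Fire safety protocols require all personnel to know the location of "
     "fire extinguishers and emergency exits."],
))
```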
## Tips for Better Evaluation

1. **Start Small**: Begin with 2-3 queries to test the format
2. **Verify Ground Truth**: Always check what's actually indexed before creating ground truth
3. **Use Representative Queries**: Include queries that reflect real user needs
4. **Test Different k Values**: Try `[1, 3, 5]` to see how results improve with more documents
5. **Compare Methods**: Use evaluation to see which pipeline performs best for your data
6. **Include Edge Cases**: Test with queries that might not have perfect matches

## Output Files

Evaluation generates several files in the `reports/` directory:

1. **CSV File**: `evaluation_results.csv` - Detailed metrics per query/pipeline/k
2. **JSON File**: `evaluation_results.json` - Complete results with summary
3. **PNG Charts**: Various visualization charts in `reports/visualizations/`
4. **HTML Report**: Comprehensive report with embedded charts

## Sample Workflow

1. **Upload documents** → Index with tags
2. **Search manually** → Find relevant chunks
3. **Create queries** → Based on document topics
4. **Extract ground truth** → Copy exact chunk text
5. **Run evaluation** → Get quantitative metrics
6. **Analyze results** → Compare pipeline performance
7. **Iterate** → Refine queries and ground truth

## Quick Reference

**Minimal Valid Input:**
```json
[
  {
    "query": "Your question",
    "ground_truth": ["Exact chunk text 1", "Exact chunk text 2"]
  }
]
```

**Full Input Example:**
```json
[
  {
    "query": "What are safety protocols?",
    "ground_truth": ["Safety protocol text from indexed document"],
    "k_values": [1, 3, 5],
    "tags": ["safety", "protocol"],
    "tag_operator": "OR"
  }
]
```

Remember: Ground truth must **exactly match** the content of your indexed document chunks!
HUGGINGFACE_DEPLOYMENT.md
ADDED
@@ -0,0 +1,325 @@
# Hugging Face Spaces Deployment Guide

Complete guide for deploying the Auto Tagging RAG System to Hugging Face Spaces.

## Quick Start

### Prerequisites

1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co)
2. **Git**: Installed and configured
3. **Repository**: Your project code pushed to a Git repository

### Step 1: Create a New Space

1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
2. Click **"Create new Space"**
3. Configure:
   - **Space name**: `auto-tagging-rag` (or your preferred name)
   - **SDK**: Select **Gradio**
   - **Visibility**: Public or Private
   - **Hardware**: CPU (free) or GPU (paid) if needed

### Step 2: Clone and Setup

```bash
# Clone your Space repository
git clone https://huggingface.co/spaces/YOUR_USERNAME/auto-tagging-rag
cd auto-tagging-rag

# Copy your project files to the Space repository
cp -r /path/to/your/auto_tagging_rag/* .
cp /path/to/your/auto_tagging_rag/.gitignore .  # If exists
```

### Step 3: Verify Files Structure

Ensure these files are present:

```
.
├── app.py               # Gradio application entry point
├── requirements.txt     # Python dependencies
├── README.md            # Space description
└── core/                # Core modules directory
    ├── __init__.py
    ├── ingest.py
    ├── index.py
    ├── retrieval.py
    ├── eval.py
    ├── tag_generator.py
    ├── reranker.py
    ├── comparison.py
    ├── visualization.py
    ├── report_generator.py
    ├── session_manager.py
    ├── session_rag.py
    └── utils.py
```

### Step 4: Update README.md for Spaces

Your `README.md` should include Hugging Face Spaces frontmatter (already included):

```markdown
---
title: Auto Tagging RAG System
emoji: 📚
colorFrom: indigo
colorTo: blue
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
license: mit
python_version: "3.10"
---
```

### Step 5: Configure Environment Variables

In your Hugging Face Space (if using OpenAI features):

1. Go to **Settings** → **Repository secrets**
2. Add security-sensitive variables:
   - `OPENAI_API_KEY`: Your OpenAI API key (required only if using OpenAI embeddings or tag generation)
   - Get it from: [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)

**Note**:
- Only `OPENAI_API_KEY` is security-sensitive and should be set as a secret
- All other configuration variables have sensible defaults
- Environment variables are automatically injected into the Space container

### Step 6: Install Dependencies

Ensure `requirements.txt` includes all necessary packages:

```txt
gradio==5.49.1
gradio-client==1.13.3
langchain>=0.1.0
langchain-community>=0.0.0
chromadb>=0.4.0
pypdf>=3.0.0
PyPDF2>=3.0.0
sentence-transformers>=2.2.0
tiktoken>=0.5.0
yake>=0.4.0
keybert>=0.8.0
spacy>=3.7.0
janome>=0.5.0
openai>=1.0.0
pytest>=7.0.0
pytest-asyncio>=0.21.0
python-dotenv>=1.0.0
PyYAML>=6.0
numpy>=1.21.0
pandas>=1.5.0
scikit-learn>=1.2.0
matplotlib>=3.5.0
jinja2>=3.1.0
mcp>=1.0.0
fastapi>=0.110.0
starlette>=0.36.3
uvicorn>=0.23.0
```

**Important for Hugging Face Spaces**:
- Models (SentenceTransformers, spaCy) are downloaded automatically on first run
- No manual model download needed
- Models are cached in `/tmp` or container storage

### Step 7: Handle Model Downloads in Code

Your `app.py` should handle model downloads gracefully. The system already does this:

```python
# In core/index.py - SentenceTransformers uses local_files_only=False on first run
# In core/tag_generator.py - spaCy models are loaded with error handling
```

**For spaCy English model**: The code automatically downloads it on first use if not available locally. A sketch of this pattern follows below.

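A minimal sketch of the auto-download pattern described above, using the standard spaCy API; the project's actual error handling in `core/tag_generator.py` may be structured differently.

```python
import spacy

def load_spacy_model(name="en_core_web_sm"):
    """Load a spaCy model, downloading it on first use if missing."""
    try:
        return spacy.load(name)
    except OSError:
        # Model not installed yet (e.g., a fresh Space container): fetch it.
        from spacy.cli import download
        download(name)
        return spacy.load(name)

nlp = load_spacy_model()
```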
| 143 |
+
### Step 8: Configure Persistence Directory
|
| 144 |
+
|
| 145 |
+
For Hugging Face Spaces, use a writable directory:
|
| 146 |
+
|
| 147 |
+
```python
|
| 148 |
+
# In app.py or core modules
|
| 149 |
+
import os
|
| 150 |
+
|
| 151 |
+
# Default for Spaces
|
| 152 |
+
PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "/tmp/chroma_data")
|
| 153 |
+
|
| 154 |
+
# For Spaces, /tmp is writable and persists during container lifecycle
|
| 155 |
+
# Note: Data is cleared on Space restart
|
| 156 |
+
```

### Step 9: Commit and Push

```bash
git add .
git commit -m "Initial commit: Auto Tagging RAG System"
git push
```

### Step 10: Build and Deploy

1. Hugging Face Spaces will automatically build your Space
2. Check the **Logs** tab for build progress
3. Wait for the build to complete (usually 2-5 minutes)
4. Your Space will be available at: `https://huggingface.co/spaces/YOUR_USERNAME/auto-tagging-rag`

## Configuration for Spaces

### Environment Variables in Space Settings

Go to **Settings** → **Repository secrets** and add (if using OpenAI):

| Variable | Description | Required | Example |
|----------|-------------|----------|---------|
| `OPENAI_API_KEY` | OpenAI API key for embeddings/tag generation | Only if using OpenAI | `sk-...` |

**Security Note**:
- Only `OPENAI_API_KEY` is security-sensitive and should be set as a Repository Secret
- All other configuration variables have sensible defaults and don't need to be set
- See `SETUP_GUIDE.md` for optional configuration variables

### Hardware Requirements

**CPU (Free Tier)**:
- Sufficient for small to medium datasets
- Good for testing and demonstrations
- Model downloads may take longer on first run

**GPU (Paid Tier)**:
- Recommended for large datasets or production use
- Faster embedding generation
- Better reranking performance

### Storage Considerations

**Important Notes**:
- Spaces have limited persistent storage
- ChromaDB data stored in `/tmp` is cleared on Space restart
- For persistent data, consider:
  - Hugging Face Datasets (for document storage)
  - An external database (e.g., PostgreSQL via API)
  - Hugging Face Hub for model artifacts

## Troubleshooting

### Build Failures

**Issue**: Build fails with "Module not found"
- **Solution**: Check that `requirements.txt` includes all dependencies
- Verify the Python version matches `python_version: "3.10"` in the README

**Issue**: spaCy model not found
- **Solution**: The code automatically downloads models on first run
- Check logs for download progress
- If the problem persists, add `python -m spacy download en_core_web_sm` to the build process (not recommended; this is handled in code)

### Runtime Errors

**Issue**: "Permission denied" errors with `/data`
- **Solution**: Use `/tmp/chroma_data` instead (already configured)

**Issue**: Out-of-memory errors
- **Solution**: Upgrade to a GPU Space or reduce `MAX_TAGS_PER_CHUNK` and `k` values

**Issue**: Models not loading
- **Solution**: Check internet connectivity during the first run
- Models download automatically on first use
- Subsequent runs use cached models

### Performance Issues

**Issue**: Slow first load
- **Solution**: This is expected; models download on the first run
- Subsequent loads are faster (models are cached)

**Issue**: Timeout errors
- **Solution**: Increase the Space timeout in settings
- Or reduce the batch processing size

## Updating Your Space

```bash
# Make changes to your code
cd auto-tagging-rag

# Commit changes
git add .
git commit -m "Update: Description of changes"
git push

# Hugging Face Spaces will automatically rebuild
```

## Public vs Private Spaces

**Public Spaces**:
- Accessible to everyone
- Great for demos and sharing
- Free hosting (with usage limits)

**Private Spaces**:
- Requires a Hugging Face Pro subscription
- Access restricted to authorized users
- Better for internal use cases

## Best Practices

1. **Optimize Model Loading**: Models load on first use; consider lazy loading (see the sketch after this list)
2. **Error Handling**: Add comprehensive error handling for network issues
3. **User Guidance**: Add clear instructions in the README for users
4. **Resource Management**: Monitor memory usage; Spaces have limits
5. **Session Management**: Sessions persist in `/tmp` but clear on restart
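
A minimal lazy-loading sketch for point 1 (illustrative only; the function name is hypothetical, not the actual `app.py` code):

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def get_embedder():
    """Instantiate the embedding model on first call only, keeping startup fast."""
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer("all-MiniLM-L6-v2")
```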

## Example Space Configuration

### README.md Frontmatter

```markdown
---
title: Auto Tagging RAG System
emoji: 📚
colorFrom: indigo
colorTo: blue
sdk: gradio
sdk_version: 5.49.1
app_file: app.py
pinned: false
license: mit
python_version: "3.10"
---

# Auto Tagging RAG System

[Your project description]
```

### .gitignore

```
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
chroma_data/
reports/
*.log
.env
.DS_Store
*.swp
*.swo
```

## Support

- **Documentation**: See `README.md` and `SETUP_GUIDE.md`
- **Issues**: Report on GitHub or Hugging Face Space
- **Community**: Hugging Face Spaces discussions

MODEL_DB_CONFIG.md
ADDED
@@ -0,0 +1,451 @@

# Model & Database Configuration Notes

## Overview

This document details all libraries, models, APIs, and database configurations used in the Auto Tagging RAG System.

## Table of Contents

1. [Embedding Models](#embedding-models)
2. [Tag Generation Libraries](#tag-generation-libraries)
3. [Reranking Models](#reranking-models)
4. [Database Configuration](#database-configuration)
5. [API Integrations](#api-integrations)
6. [Model Download & Caching](#model-download--caching)

## Embedding Models

### Primary: SentenceTransformers

**Library**: `sentence-transformers>=2.2.0`

**Default Model**: `all-MiniLM-L6-v2`
- **Provider**: Hugging Face
- **Dimensions**: 384
- **Size**: ~80MB
- **Language**: Primarily English (for multilingual text, see `paraphrase-multilingual-MiniLM-L12-v2` below)
- **Performance**: Fast, good quality for most use cases
- **Download**: Automatic on first use
- **Location**: Cached in `~/.cache/huggingface/` or `~/.cache/torch/sentence_transformers/`

**Usage**:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
embeddings = model.encode(["Your text here"])
```

**Configuration**:
- Set via `ST_EMBED_MODEL` environment variable
- Alternative models: `all-mpnet-base-v2` (768 dims), `paraphrase-multilingual-MiniLM-L12-v2` (384 dims)

**Offline Mode**:
- Uses `local_files_only=True` to prevent network access after the initial download
- The model is cached locally for offline operation

### Alternative: OpenAI Embeddings

**Library**: `openai>=1.0.0`

**Default Model**: `text-embedding-3-small`
- **Provider**: OpenAI API
- **Dimensions**: 1536
- **Cost**: Pay-per-use
- **Rate Limits**: Varies by plan
- **Requires**: `OPENAI_API_KEY` environment variable

**Alternative Models**:
- `text-embedding-3-large`: 3072 dimensions, higher quality
- `text-embedding-ada-002`: legacy model, 1536 dimensions

**Usage**:
```python
import os

from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.embeddings.create(
    model="text-embedding-3-small",
    input=["Your text here"]
)
```

**Configuration**:
- Set via `OPENAI_EMBED_MODEL` environment variable
- Force usage with `USE_OPENAI_EMBEDDINGS=true`

**ChromaDB Integration**:
- When using OpenAI embeddings, ChromaDB collections are namespaced: `documents__oai_1536`
- When using SentenceTransformers, collections are: `documents__st_384`
- Switching providers requires re-indexing documents
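
The naming rule above can be pictured with a small helper (hypothetical; the real naming happens inside `core/index.py`):

```python
def collection_name(base: str, provider: str, dim: int) -> str:
    """Namespace a collection by embedding provider and dimension."""
    suffix = "oai" if provider == "openai" else "st"
    return f"{base}__{suffix}_{dim}"

collection_name("documents", "openai", 1536)                 # -> "documents__oai_1536"
collection_name("documents", "sentence-transformers", 384)   # -> "documents__st_384"
```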

## Tag Generation Libraries

### YAKE (Yet Another Keyword Extractor)

**Library**: `yake>=0.4.0`

**Purpose**: Language-independent keyword extraction
- **Languages**: English, Japanese, and many others
- **Method**: Statistical approach based on word co-occurrence
- **Advantages**: No model downloads, fast, works offline
- **Configuration**: Controlled via `MAX_TAGS_PER_CHUNK`, `MIN_TAG_LENGTH`, `MAX_TAG_LENGTH`

**Usage**:
```python
import yake

kw_extractor = yake.KeywordExtractor(
    lan="en",  # or "ja"
    n=3,       # max words per phrase
    top=10     # max tags
)
keywords = kw_extractor.extract_keywords(text)
```

### KeyBERT

**Library**: `keybert>=0.8.0`

**Purpose**: Keyword extraction using BERT embeddings
- **Base Model**: Uses SentenceTransformers models (default: `all-MiniLM-L6-v2`)
- **Languages**: Works with multilingual BERT models
- **Method**: Extracts keywords by comparing document embeddings with candidate phrase embeddings
- **Advantages**: Higher quality than YAKE, leverages semantic understanding

**Usage**:
```python
from keybert import KeyBERT

kw_model = KeyBERT(model='all-MiniLM-L6-v2')
keywords = kw_model.extract_keywords(text, top_n=10)
```

### spaCy

**Library**: `spacy>=3.7.0`

**Model**: `en_core_web_sm` (English)
- **Size**: ~13MB
- **Download**: `python -m spacy download en_core_web_sm`
- **Location**: Cached in the spaCy data directory

**Purpose**: Noun phrase extraction for English
- **Method**: POS tagging + noun phrase chunking
- **Advantages**: High precision, grammar-aware
- **Limitation**: English only

**Usage**:
```python
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp(text)
noun_phrases = [chunk.text for chunk in doc.noun_chunks]
```

### Janome

**Library**: `janome>=0.5.0`

**Purpose**: Japanese morphological analysis
- **Languages**: Japanese only
- **Method**: Dictionary-based tokenization and POS tagging (pure-Python analyzer using the same IPADIC dictionary as MeCab)
- **Advantages**: No external dependencies, pure Python
- **Download**: No model download required (built-in dictionaries)

**Usage**:
```python
from janome.tokenizer import Tokenizer

t = Tokenizer()
tokens = t.tokenize(text)
nouns = [token.surface for token in tokens if token.part_of_speech.startswith('名詞')]
```

### OpenAI Tag Generation

**Library**: `openai>=1.0.0`

**Model**: Configurable (default: `gpt-4o-mini`)
- **Purpose**: AI-generated tags using a language model
- **Method**: Prompt-based generation
- **Advantages**: Highest quality, contextual understanding
- **Cost**: Pay-per-use
- **Requires**: `OPENAI_API_KEY`

**Usage**:
```python
import os

from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": f"Extract key tags from: {text}"}]
)
```

**Configuration**:
- Set the model via `OPENAI_MODEL` environment variable
- Options: `gpt-4o-mini`, `gpt-4`, `gpt-3.5-turbo`

## Reranking Models

### Cross-Encoder Reranker

**Library**: `sentence-transformers` (cross-encoder models)

**Default Model**: `cross-encoder/ms-marco-MiniLM-L-6-v2`
- **Provider**: Hugging Face
- **Size**: ~40MB
- **Purpose**: Re-rank retrieved documents based on query-document relevance
- **Method**: Cross-attention between query and document
- **Download**: Automatic on first use
- **Performance**: Most accurate reranking method

**Usage**:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
scores = model.predict([(query, doc) for doc in documents])
```

**Configuration**:
- Set the model via `RERANKER_MODEL` environment variable
- Alternative models: `cross-encoder/ms-marco-electra-base`, `cross-encoder/quora-distilroberta-base`

### Semantic Reranking

**Method**: Uses embedding similarity (vector store)
- **Library**: `sentence-transformers` or `openai`
- **Purpose**: Re-rank based on semantic similarity
- **Advantages**: No additional model download
- **Performance**: Good, but less accurate than a cross-encoder
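
A minimal sketch of embedding-based reranking (illustrative, assuming the default SentenceTransformers model; the project's version lives in `core/reranker.py`):

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

def semantic_rerank(query, documents, top_k=5):
    """Re-rank candidate documents by cosine similarity to the query."""
    q_emb = model.encode(query, convert_to_tensor=True)
    d_embs = model.encode(documents, convert_to_tensor=True)
    scores = util.cos_sim(q_emb, d_embs)[0]
    ranked = sorted(zip(documents, scores.tolist()), key=lambda x: x[1], reverse=True)
    return ranked[:top_k]
```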

### Heuristic Reranking

**Method**: Rule-based scoring
- **Factors**: Tag overlap, keyword frequency, document length
- **Advantages**: Fast, no model required
- **Performance**: Baseline method
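
A toy heuristic scorer of the kind described (assumed shape; the overlap weight here is invented for illustration):

```python
def heuristic_score(base_score, query_tags, doc_tags, overlap_weight=0.3):
    """Boost a retrieval score by the fraction of query tags the document carries."""
    overlap = len(set(query_tags) & set(doc_tags)) / max(len(query_tags), 1)
    return base_score + overlap_weight * overlap
```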

**Configuration**:
- Set the strategy via `RERANKER_STRATEGY` environment variable
- Options: `cross_encoder`, `semantic`, `heuristic`, `openai`

## Database Configuration

### ChromaDB

**Library**: `chromadb>=0.4.0`

**Type**: Vector database with persistent storage
- **Storage Format**: SQLite (metadata) + binary files (vectors)
- **Location**: Configurable via `CHROMA_PERSIST_DIR` (default: `./chroma_data`)
- **Collection Naming**: Namespaced by embedding provider/dimension
  - SentenceTransformers: `session_xxxxx__st_384`
  - OpenAI: `session_xxxxx__oai_1536`

**Configuration**:
```python
import chromadb

client = chromadb.PersistentClient(path="./chroma_data")
collection = client.create_collection(
    name="documents",
    metadata={"hnsw:space": "cosine"}
)
```

**Metadata Schema**:
- `source_name`: Original filename
- `doc_id`: Unique document identifier
- `chunk_index`: Chunk index within the document
- `tags`: Comma-separated tag string
- `lang`: Language code (en/ja)
- `user_tags`: User-provided tags (comma-separated)
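
A query against such a collection can filter on this metadata; a minimal sketch, reusing the `collection` from the configuration example above (the query text and filter values are illustrative):

```python
# Retrieve the 5 nearest chunks, restricted to English-language documents.
results = collection.query(
    query_texts=["emergency procedures"],
    n_results=5,
    where={"lang": "en"},  # exact-match metadata filter
)
```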

**Storage Structure**:
```
chroma_data/
├── chroma.sqlite3            # Metadata database
├── session_xxxxx__st_384/    # Vector data for each collection
│   ├── data_level0.bin
│   ├── header.bin
│   ├── length.bin
│   └── link_lists.bin
```

**Session Isolation**:
- Each user session has a separate ChromaDB collection
- Collections are prefixed with `session_` + an 8-character session ID
- Sessions persist across server restarts (collections are stored on disk)

**Cleanup**:
- Sessions expire after `SESSION_TIMEOUT` seconds (default: 3600)
- Expired sessions are cleaned up automatically
- Manual cleanup: delete the `chroma_data/` directory

## API Integrations

### OpenAI API

**Endpoint**: `https://api.openai.com/v1/`

**Endpoints Used**:
1. **Embeddings**: `POST /embeddings`
   - Models: `text-embedding-3-small`, `text-embedding-3-large`
   - Rate limits: varies by plan

2. **Chat Completions**: `POST /chat/completions`
   - Models: `gpt-4o-mini`, `gpt-4`
   - Used for: tag generation, metadata detection

**Authentication**:
- API key in the `OPENAI_API_KEY` environment variable
- Format: `sk-...`

**Rate Limits**:
- Free tier: limited requests/minute
- Paid tier: higher limits
- Handling: retries with exponential backoff

**Error Handling**:
- Network errors: graceful fallback to SentenceTransformers
- Rate limits: retry with backoff
- Invalid key: log a warning, continue with the fallback

### Hugging Face Hub (Automatic)

**Purpose**: Model downloads via `sentence-transformers`

**Models Downloaded**:
- Embeddings: `all-MiniLM-L6-v2`
- Reranker: `cross-encoder/ms-marco-MiniLM-L-6-v2`

**Caching**:
- Location: `~/.cache/huggingface/` or `~/.cache/torch/sentence_transformers/`
- First download: requires internet
- Subsequent runs: use cached models (offline)

**Offline Mode**:
- Set `local_files_only=True` in code
- Prevents network access after the initial download
- Falls back gracefully if a model is not cached

## Model Download & Caching

### Automatic Downloads

The following models are downloaded automatically on first use:

1. **SentenceTransformers Embeddings** (~80MB)
   - Triggered: first vector store initialization
   - Location: `~/.cache/torch/sentence_transformers/all-MiniLM-L6-v2/`

2. **spaCy English Model** (~13MB)
   - Triggered: first English tag-generation run if the model is missing (it can also be pre-downloaded with `python -m spacy download en_core_web_sm`)
   - Location: spaCy data directory

3. **Cross-Encoder Reranker** (~40MB)
   - Triggered: first reranking operation
   - Location: `~/.cache/torch/sentence_transformers/cross-encoder/ms-marco-MiniLM-L-6-v2/`

### Manual Pre-download

To download all models before first use:

```bash
# 1. spaCy English model
python -m spacy download en_core_web_sm

# 2. SentenceTransformers embeddings (via Python)
python3 -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"

# 3. Cross-encoder reranker (via Python)
python3 -c "from sentence_transformers import CrossEncoder; CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')"
```

### Cache Management

**View Cached Models**:
```bash
# SentenceTransformers
ls ~/.cache/torch/sentence_transformers/

# spaCy
python -m spacy info en_core_web_sm
```

**Clear Cache** (if needed):
```bash
# SentenceTransformers
rm -rf ~/.cache/torch/sentence_transformers/

# spaCy (re-download)
python -m spacy download en_core_web_sm
```

### Disk Space Requirements

- **Minimum**: ~200MB for all models
- **Recommended**: ~500MB for cache and data
- **With Data**: additional space for ChromaDB and reports

## Offline Operation

### Fully Offline After Setup

After the initial model downloads, the system works **completely offline**:

1. ✅ All models cached locally
2. ✅ No API calls (unless OpenAI is explicitly enabled)
3. ✅ No network access required
4. ✅ Tiktoken fallback for token counting

### Components Requiring Internet

- **Initial Setup**: model downloads (one-time)
- **OpenAI Features**: if `OPENAI_API_KEY` is set (optional)
- **Gradio Telemetry**: anonymous usage stats (can be disabled)

### Disable Network Access

Set in code (see the sketch after this list):
- `local_files_only=True` for SentenceTransformers
- Remove `OPENAI_API_KEY` from the environment
- Tiktoken falls back automatically if the network is unavailable
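
A minimal sketch of these settings (the environment variables are standard Hugging Face/Gradio knobs, noted here as assumptions rather than project code):

```python
import os

os.environ["HF_HUB_OFFLINE"] = "1"                 # block Hugging Face Hub network calls
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"   # disable Gradio telemetry

from sentence_transformers import SentenceTransformer

# Load only from the local cache; this raises if the model was never downloaded.
model = SentenceTransformer("all-MiniLM-L6-v2", local_files_only=True)
```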

## Version Compatibility

### Python Version
- **Minimum**: Python 3.8
- **Recommended**: Python 3.10+
- **Tested**: Python 3.8, 3.9, 3.10, 3.11

### Library Versions
See `requirements.txt` for exact versions. Key libraries:
- `gradio>=5.49.1`
- `chromadb>=0.4.0`
- `sentence-transformers>=2.2.0`
- `spacy>=3.7.0`

### Model Compatibility
- All models tested with the current library versions
- Breaking changes in model formats are handled automatically
- Updates: check Hugging Face for model updates

## Performance Notes

### Model Loading Times
- SentenceTransformers: ~2-5 seconds (first load)
- spaCy: ~1-2 seconds (first load)
- Cross-encoder: ~2-5 seconds (first load)
- Subsequent loads: cached, faster

### Memory Usage
- Base: ~500MB (Python + libraries)
- SentenceTransformers: ~200MB
- spaCy: ~100MB
- Cross-encoder: ~150MB
- ChromaDB: varies with document count
- **Total**: ~1-2GB typical usage

### Optimization Tips
- Use smaller embedding models for lower memory
- Disable reranking if not needed
- Use CPU-only mode (the default); GPU is optional
- Clear the ChromaDB cache if disk space is limited

README.md
CHANGED
@@ -1,12 +1,613 @@
 ---
-title: Auto Tagging
-emoji:
+title: Auto Tagging RAG System
+emoji: 📚
 colorFrom: indigo
-colorTo:
+colorTo: blue
 sdk: gradio
 sdk_version: 5.49.1
 app_file: app.py
 pinned: false
+license: mit
+python_version: "3.10"
 ---

# Auto Tagging RAG System

A comprehensive system for evaluating multiple **RAG (Retrieval-Augmented Generation)** pipelines with support for flat tags, hybrid retrieval, reranking, and extensive evaluation metrics.

## Features

### Retrieval Pipelines
- **Base RAG**: Standard vector similarity search
- **Tag Filter RAG**: Filter documents using flat tags with AND/OR/NOT operators
- **Hybrid RAG**: Combine vector search and tag-based filtering with weighted scoring
- **Hybrid Rerank RAG**: Apply reranking after hybrid retrieval for refined results

### Tag System
- **Flat Tags**: Automatic non-hierarchical tag generation using:
  - YAKE (Yet Another Keyword Extractor)
  - KeyBERT (keyword extraction with BERT)
  - spaCy (noun phrase extraction for English)
  - Janome (Japanese tokenization and noun extraction)
  - OpenAI-based generation (optional)
- **Manual Tag Input**: Users can add custom tags during document upload
  - Manual tags are combined with auto-generated tags
  - Manual tags take priority (prepended to the tag list)
- **Multi-language Support**: English and Japanese tag generation

### Evaluation & Analysis
- **Extended Metrics**: Precision@k, nDCG@k, MRR, Hit@k, Semantic Similarity, Latency (mean, p50, p90), User Satisfaction
- **Comparison Framework**: Side-by-side comparison of all retrieval methods
- **Visualization**: Bar charts, line plots, scatter plots, box plots, stacked bar charts, Pareto charts
- **Report Generation**: Comprehensive HTML, CSV, and JSON reports with embedded visualizations and representative examples

### Additional Features
- **Session Management**:
  - Browser-based session persistence using localStorage
  - Session ID automatically saved and restored on page refresh
  - Multi-user support with isolated data and retrieval contexts
  - Document count display (shows unique documents, not chunks)
- **Gradio UI**: User-friendly interface for all operations
- **MCP Server**: Model Context Protocol server for programmatic access
- **API Export**: All main functions exposed via the Gradio Client API

## Repository Layout

```
.
├── app.py                      # Gradio UI entry point; defines interface and exposed functions
├── core/                       # Core logic modules
│   ├── ingest.py               # Document loaders, chunking, flat tag generation
│   ├── index.py                # Embeddings, vector DB (ChromaDB), metadata/tag filtering
│   ├── retrieval.py            # All RAG pipelines (Base, TagFilter, Hybrid, HybridRerank)
│   ├── eval.py                 # Evaluation metrics and batch evaluation
│   ├── tag_generator.py        # Flat tag generation (YAKE, KeyBERT, spaCy, Janome, OpenAI)
│   ├── reranker.py             # Reranking strategies (cross-encoder, semantic, heuristic)
│   ├── comparison.py           # RAG method comparison framework
│   ├── visualization.py        # Chart generation (bar, line, scatter, box, stacked, Pareto)
│   ├── report_generator.py     # Comprehensive report generation (HTML, CSV, JSON)
│   ├── session_manager.py      # User session management and isolation
│   ├── session_rag.py          # Session-aware RAG manager
│   └── utils.py                # Shared helpers (e.g., PII masking, ID generation)
├── tests/                      # pytest test suite (E2E tests)
│   ├── conftest.py             # Pytest fixtures
│   ├── test_accuracy.py        # Accuracy tests (tag generation, retrieval, metrics)
│   ├── test_mcp_server.py      # MCP server tests
│   ├── test_ux.py              # User experience tests
│   ├── test_robustness.py      # Edge cases and error handling
│   ├── test_user_scenarios.py  # Non-technical user scenarios
│   └── test_japanese_support.py # Japanese language support tests
├── reports/                    # Evaluation results (CSV/JSON/PNG/HTML)
├── requirements.txt            # Python dependencies
└── README.md                   # This file
```

## Setup

### Prerequisites

- Python 3.8+
- pip or conda

### Installation

1. Clone the repository:
```bash
git clone <repository-url>
cd auto-tagging-rag
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

**Note:** Models (SentenceTransformers, spaCy) are downloaded automatically on first run. No manual download required.

3. Create the necessary directories:
```bash
mkdir -p reports chroma_data
```

### Environment Variables

**Security-Important Variables:**

Create a `.env` file (copy from `.env.example`) for sensitive configuration:

```bash
cp .env.example .env
# Edit .env and add your API key
```

- `OPENAI_API_KEY` (required only if using OpenAI features): Your OpenAI API key for embeddings, tag generation, and metadata detection
  - Get your API key from: [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys)
  - **Never commit `.env` to version control** (already in `.gitignore`)

**Note:** All other configuration variables have sensible defaults and can be overridden if needed. See `SETUP_GUIDE.md` for the full configuration options.

## Deployment

### Hugging Face Spaces

Deploy this app to Hugging Face Spaces for free hosting:

1. **Create a Space**: Go to [Hugging Face Spaces](https://huggingface.co/spaces) and create a new Gradio Space
2. **Push Your Code**: Clone your Space repository and push your code
3. **Configure**: Set environment variables in Space Settings (optional)
4. **Deploy**: Spaces automatically builds and deploys your app

**See [HUGGINGFACE_DEPLOYMENT.md](HUGGINGFACE_DEPLOYMENT.md) for the complete deployment guide.**

**Key Points for Spaces**:
- Models download automatically on first run (no manual download needed)
- Use `/tmp/chroma_data` for persistence (clears on restart)
- Environment variables are configured in Space Settings
- A free CPU tier is available; use the GPU tier for better performance

### Local Deployment

See [SETUP_GUIDE.md](SETUP_GUIDE.md) for local deployment instructions.

## Usage

### Using the Gradio API (gradio_client)

All main functions are exposed via the Gradio API with `api_name`:

- `build_rag`: Build a RAG index from uploaded files
- `search`: Search documents using both pipelines
- `chat`: Chat interface with the RAG system
- `evaluate`: Run quantitative evaluation

Example usage:

```python
from gradio_client import Client

client = Client("http://your-server:7860/")

# Build RAG index
result = client.predict(
    files=["doc1.pdf", "doc2.pdf"],
    language="en",
    user_tags="hospital-protocol, urgent",  # Optional manual tags
    api_name="/build_rag"
)

# Search documents
results = client.predict(
    query="What are emergency procedures?",
    k=5,
    tags="emergency, procedures",
    tag_operator="OR",
    api_name="/search_all"
)
```

### MCP Server

The system can run as an MCP (Model Context Protocol) server for programmatic access:

```bash
python app.py --mcp
```

#### Connecting to MCP Server

Add to your MCP client configuration (e.g., for Claude Desktop):

```json
{
  "mcpServers": {
    "auto-tagging-rag": {
      "command": "python",
      "args": ["/path/to/app.py", "--mcp"],
      "env": {}
    }
  }
}
```

#### Available MCP Tools

1. **search_documents**: Search documents using the RAG system
   - Parameters: `query`, `k`, `pipeline`, `tags`, `tag_operator`

2. **evaluate_retrieval**: Evaluate RAG performance with batch queries
   - Parameters: `queries` (array), `output_file`

## UI Tabs

### 📁 Tab 1: Upload & Tagging

**Purpose**: Document processing and tag generation

**Components:**
- File upload area (PDF/TXT with drag & drop)
- **Language Selection**: Auto-detect, English (en), or Japanese (ja)
- **Manual Tag Input**: Add custom tags (comma-separated) that will be combined with auto-generated tags
  - Example: `hospital-protocol, urgent, confidential`
  - Manual tags are prepended to auto-generated tags
- Progress bar for processing status
- Processing Summary: files processed, chunks created, tags generated, user tag count
- Indexed Chunks (preview): shows chunks with tags and preview text
- Tag visualization display (cloud or list format)
- Reset Index button to clear all data

**After Build:**
- Build Status (processed count, indexed chunks, tags generated)
- File Summary (Filename, Chunks, Language, Tags, User Tags)
- Indexed Chunks (preview with metadata and the first 160 chars)
- Tag Visualization (top 20 tags with frequency)

### 🔍 Tab 2: Search & Compare

**Purpose**: Test different retrieval methods side-by-side

**Components:**
- Search query input box
- **Method Selection**: Compare all 4 retrieval methods simultaneously:
  - Base RAG (vector similarity)
  - Tag Filter RAG (tag-based filtering)
  - Hybrid RAG (weighted combination)
  - Hybrid Rerank RAG (with reranking)
- **Tags**: Comma-separated tags for tag-based filtering (optional)
- **Tag Operator**: AND/OR/NOT for tag filtering logic
- **Vector/Tag Weight Sliders**: Adjust hybrid retrieval weights (default 0.7/0.3)
- **k**: Number of results to retrieve (slider: 1-20)
- **Side-by-side Results Display**:
  - Answers from each method
  - Retrieved documents with scores
  - Tags used in retrieval
  - Tags found in results
  - Response times
  - Quick comparison metrics

### 💬 Tab 3: Chat Interface

**Purpose**: Natural conversation with the tag-enhanced RAG

**Components:**
- Chat message input
- **Pipeline Selection**: Choose one retrieval method:
  - Base RAG
  - Tag Filter RAG
  - Hybrid RAG
  - Hybrid Rerank RAG
- **Tag Filter Toggle**: Enable/disable tag usage
- **Tag Input**: Specify tags for tag-based pipelines (comma-separated)
- Conversation history display
- Visible source documents with tags
- Adjustable retrieval parameters (k)

### 📊 Tab 4: Analytics & Evaluation

**Purpose**: Performance visualization and metrics

**Components:**
- Input queries in JSON format with ground truth
- **User Satisfaction Scores**: Optional satisfaction ratings per query
- **Metrics Display**:
  - Precision@k, nDCG@k, MRR, Hit@k
  - Semantic Similarity
  - Latency (mean, p50, p90)
  - User Satisfaction (if provided)
- **Visualization Tabs**:
  - Bar Charts: method comparison metrics
  - Line Plots: metric trends over k values
  - Scatter Plots: correlation analysis
  - Box Plots: distribution analysis
  - Stacked Bar Charts: method breakdown
  - Pareto Charts: performance ranking
- **Summary Statistics**: Aggregated metrics across all queries
- **Export Buttons**: CSV/PNG export for results and charts
- **Report Generation**: Comprehensive HTML reports with embedded visualizations

### ⚙️ Tab 5: Settings & Management

**Purpose**: System configuration and user management

**Components:**
- **Tag Generation Parameters**:
  - Max Tags Per Chunk (5-50)
  - Min Tag Length (1-5 words)
  - Max Tag Length (1-5 words)
  - Tag Generation Method (auto, yake, keybert, spacy, janome, openai)
- **Hybrid Search Weights Configuration**:
  - Default Vector Weight (0.0-1.0)
  - Default Tag Weight (0.0-1.0)
- **Database Management**:
  - Clear All Data
  - Export Database
  - Import Database
- **API Key Configuration**: OpenAI API key management
- **Embedding Configuration**: Select the embedding provider (SentenceTransformers/OpenAI) and model

### 🎨 Global UI Elements

- **Session Indicator**: Shows the current session ID (persisted in browser localStorage)
- **Document Count**: Shows the total number of unique documents indexed (updates automatically)
- **Processing Status**: System status indicator
- **Session Persistence**: Session automatically restored from localStorage on page refresh

## Evaluation

### Quantitative Evaluation

The system evaluates all 4 retrieval pipelines (Base RAG, Tag Filter RAG, Hybrid RAG, Hybrid Rerank RAG) on multiple metrics (reference sketches follow the list):

**Metrics:**
- **Precision@k**: Fraction of retrieved documents that are relevant
- **nDCG@k**: Normalized Discounted Cumulative Gain at k
- **MRR**: Mean Reciprocal Rank
- **Hit@k**: Whether at least one relevant document is in the top k
- **Semantic Similarity**: Average cosine similarity between the query and retrieved documents
- **Latency**: Response time (mean, p50, p90 percentiles)
- **User Satisfaction**: Average satisfaction score (if provided)
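
Rough reference implementations of the rank metrics (illustrative only; the project's versions live in `core/eval.py`):

```python
import math

def precision_at_k(retrieved, relevant, k):
    """Fraction of the top-k retrieved IDs that are in the relevant set."""
    return sum(1 for d in retrieved[:k] if d in relevant) / k

def hit_at_k(retrieved, relevant, k):
    """1.0 if any relevant ID appears in the top k, else 0.0."""
    return float(any(d in relevant for d in retrieved[:k]))

def mrr(retrieved, relevant):
    """Reciprocal rank of the first relevant ID (0.0 if none is found)."""
    for rank, d in enumerate(retrieved, start=1):
        if d in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(retrieved, relevant, k):
    """Binary-relevance nDCG: DCG of the hits over the ideal DCG."""
    dcg = sum(1.0 / math.log2(i + 2) for i, d in enumerate(retrieved[:k]) if d in relevant)
    ideal = sum(1.0 / math.log2(i + 2) for i in range(min(k, len(relevant))))
    return dcg / ideal if ideal else 0.0
```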

### Evaluation Input Format

```json
[
  {
    "query": "What are emergency procedures?",
    "ground_truth": ["chunk_id_1", "chunk_id_2", "chunk_id_3"],
    "k_values": [1, 3, 5],
    "tags": ["emergency", "procedure", "triage"],
    "user_satisfaction": 4.5
  }
]
```

**Note:** `ground_truth` should be a list of chunk IDs that are relevant to the query. The system automatically retrieves chunk IDs from the indexed documents.

### Evaluation Results

Results are saved to the `reports/` directory:
- **CSV**: Detailed metrics per query and pipeline
- **JSON Summary**: Aggregated statistics and metadata
- **Aggregated Stats CSV**: Summary by pipeline and k value
- **Examples JSON**: Representative examples (best/worst performing queries)
- **Visualizations**: PNG charts for all visualization types
- **HTML Report**: Comprehensive report with embedded visualizations

## Metadata Schema

Chunks are tagged with the following metadata:
```json
{
  "doc_id": "uuid",
  "chunk_id": "uuid",
  "source_name": "filename.pdf",
  "lang": "ja|en",
  "tags": ["user-tag1", "user-tag2", "auto-tag1", "auto-tag2"],
  "chunk_index": 0,
  "chunk_size": 1000
}
```

**Tag Storage:**
- Tags are stored as comma-separated strings in ChromaDB metadata
- Converted back to lists when retrieved
- Manual tags (user-provided) are prepended to auto-generated tags
- Tags are normalized (lowercased, deduplicated)

**Document Counting:**
- The system counts unique documents (by `doc_id`), not chunks
- Each document can have multiple chunks
- The document count is displayed in the UI header

## Documentation

### Setup & Configuration
- **[SETUP_GUIDE.md](SETUP_GUIDE.md)**: Complete setup instructions, environment variables, deployment options, and troubleshooting
- **[MODEL_DB_CONFIG.md](MODEL_DB_CONFIG.md)**: Detailed documentation of all models, libraries, APIs, and database configuration
- **[HUGGINGFACE_DEPLOYMENT.md](HUGGINGFACE_DEPLOYMENT.md)**: Guide for deploying to Hugging Face Spaces

### Testing
- **[E2E_TEST_DESIGN.md](E2E_TEST_DESIGN.md)**: Comprehensive E2E test cases covering accuracy, UX, robustness, and non-technical user scenarios (46 test cases)

### Usage Guides
- **[EVALUATION_GUIDE.md](EVALUATION_GUIDE.md)**: How to run evaluation with sample queries

## Testing

See **[E2E_TEST_DESIGN.md](E2E_TEST_DESIGN.md)** for the comprehensive test design document with 46 test cases covering:

**Test Categories**:
- **Accuracy Tests** (13 cases): Tag generation, retrieval accuracy, evaluation metrics
- **User Experience Tests** (12 cases): Non-technical user scenarios, advanced usage, UI/UX quality
- **Robustness Tests** (13 cases): Error handling, edge cases, data integrity
- **Performance Tests** (5 cases): Response time, scalability
- **Integration Tests** (3 cases): API integration, end-to-end workflows

**Test Execution**:
- Manual tests: Follow the step-by-step instructions in the test design document
- Automated tests: Run the pytest suite
```bash
pytest tests/ -v
```

## Architecture

### Retrieval Pipelines

**1. Base RAG:**
- Vector similarity search using embeddings
- Returns the top-k results ranked by cosine similarity
- No filtering applied

**2. Tag Filter RAG:**
- Filters documents using flat tags with boolean operators (AND/OR/NOT)
- Vector search within the filtered subset
- Returns the top-k results with tag context

**3. Hybrid RAG:**
- Performs both vector search and tag-based filtering
- Combines scores with configurable weights (default: 0.7 vector, 0.3 tag; sketched below)
- Normalizes and merges results
- Returns the top-k results ranked by hybrid score
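
A sketch of the weighted combination (the function shape is assumed; the actual implementation lives in `core/retrieval.py`):

```python
def hybrid_score(vector_score, tag_score, w_vector=0.7, w_tag=0.3):
    """Blend normalized vector-similarity and tag-match scores."""
    return w_vector * vector_score + w_tag * tag_score
```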

**4. Hybrid Rerank RAG:**
- Performs hybrid retrieval (steps 1-3 above)
- Applies reranking using a cross-encoder or semantic similarity
- Returns the top-k reranked results

### Tag-Based System Architecture

**Tag Generation:**
1. **Language Detection**: Auto-detect the document language (English/Japanese)
2. **Manual Tag Input**: The user can add custom tags during upload (optional)
3. **Method Selection**: Choose the tag generation method based on availability and language
4. **Tag Extraction**: Extract tags using the selected method (YAKE, KeyBERT, spaCy, Janome, or OpenAI)
5. **Tag Merging**: Combine manual tags with auto-generated tags
   - Manual tags are prepended (higher priority)
   - Identical tags are deduplicated
6. **Tag Filtering**: Remove stopwords, normalize, deduplicate
7. **Storage**: Store tags in document metadata (comma-separated string in ChromaDB)

**Tag Filtering:**
- **OR Operator**: Documents matching any tag (union)
- **AND Operator**: Documents matching all tags (intersection)
- **NOT Operator**: Exclude documents matching the specified tags
- **Post-filtering**: For AND/NOT operators, results are filtered after initial retrieval (see the sketch after this list)
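
A toy post-filter matching these operator semantics (illustrative; the real logic is in `core/retrieval.py`):

```python
def tags_match(doc_tags, query_tags, operator):
    """Decide whether a document's tags satisfy the query under the given operator."""
    doc, q = set(doc_tags), set(query_tags)
    if operator == "AND":
        return q <= doc          # every query tag present
    if operator == "NOT":
        return not (q & doc)     # no query tag present
    return bool(q & doc)         # OR: at least one query tag present
```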

### Reranking Strategies

1. **Cross-Encoder**: Uses transformer-based cross-encoder models (e.g., `cross-encoder/ms-marco-MiniLM-L-6-v2`)
2. **OpenAI**: Uses the OpenAI API for semantic reranking (if available)
3. **Semantic Similarity**: Re-ranks by query-document semantic similarity
4. **Heuristic**: Simple score-based reranking

## Vector Database & Embeddings

- **ChromaDB** with persistence
- **Embeddings Provider**:
  - OpenAI (if `OPENAI_API_KEY` is present): `OPENAI_EMBED_MODEL` (default `text-embedding-3-small`)
  - SentenceTransformers fallback: `ST_EMBED_MODEL` (default `all-MiniLM-L6-v2`)
- **Collections**: Namespaced by provider/dimension and session to avoid embedding mismatches
  - Format: `session_{session_id[:8]}` for user sessions
  - Each session has isolated ChromaDB collections
- **Tag Filtering**: Supported for flat tags with OR/AND/NOT operators
- **Document Counting**: Counts unique documents (by `doc_id`), not chunks
- **Session Isolation**: Each user session has isolated ChromaDB collections

## Session Management

The system supports multi-user deployments with browser-based session persistence:

- **Session Creation**: Automatic session creation on first page load
- **Session Persistence**:
  - The session ID is automatically saved to browser localStorage
  - The session is restored automatically on page refresh
  - Data persists across browser sessions and refreshes
- **Session Isolation**: Each session has its own ChromaDB collection
  - Collection name format: `session_{session_id[:8]}`
  - All documents, tags, and metadata are isolated per session
- **Session Display**:
  - The session ID is shown in the UI header (first 8 characters)
  - The document count is displayed and updated automatically
- **Session Timeout**: Configurable timeout (default 3600 seconds)
- **Session Cleanup**: Automatic cleanup of expired sessions
- **Session-Aware RAG**: All RAG operations are scoped to the user's session
- **New Session Creation**: If a session ID exists in localStorage but not on the server, a new session is created and localStorage is updated

## Deployment

### Local Deployment

```bash
python app.py
```

The app will be available at `http://localhost:7860`

### Hugging Face Spaces

The app is configured for Hugging Face Spaces deployment. The `README.md` frontmatter includes the Spaces-specific configuration.

### Docker Deployment

```dockerfile
FROM python:3.10-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

CMD ["python", "app.py"]
```

### Environment Configuration

For production deployments, ensure:
- The ChromaDB persistence directory is writable
- The OpenAI API key is set (if using OpenAI features)
- Required model files are downloaded (spaCy, SentenceTransformers)
- There is sufficient disk space for the vector database and reports

## Dependencies

All dependencies are included in `requirements.txt`. Key dependencies include:

### Core Dependencies
- `gradio>=5.49.1`: Web UI framework
- `chromadb>=0.4.0`: Vector database
- `sentence-transformers>=2.2.0`: Embeddings
- `pypdf>=3.0.0`: PDF processing
- `pandas>=1.5.0`: Data processing
- `numpy>=1.21.0`: Numerical operations
- `scikit-learn>=1.2.0`: Machine learning utilities

### Tag Generation (English and Japanese)
- `yake>=0.4.0`: Keyword extraction
- `keybert>=0.8.0`: BERT-based keyword extraction
- `spacy>=3.7.0`: NLP processing (English); the `en_core_web_sm` model is required
- `janome>=0.5.0`: Japanese tokenization

### Visualization and Reports
- `matplotlib>=3.5.0`: Chart generation
- `jinja2>=3.1.0`: HTML template rendering

### Optional
- `openai>=1.0.0`: OpenAI API client (for OpenAI embeddings, tag generation, and reranking)

See `requirements.txt` for the complete list.

## Troubleshooting

### Tag Generation Not Working
- Ensure all dependencies are installed: `pip install -r requirements.txt`
- For English: download the spaCy model: `python -m spacy download en_core_web_sm`
- For Japanese: ensure `janome` is installed (included in requirements.txt)
- Verify documents are correctly detected as the appropriate language

### Visualization Errors
- Ensure all dependencies are installed: `pip install -r requirements.txt` (includes `matplotlib`)
- Check that the `reports/` directory is writable

### Reranking Errors
- Cross-encoder models are downloaded automatically on first use
- For OpenAI reranking, ensure `OPENAI_API_KEY` is set
- Check the internet connection for model downloads

### Session Issues
- The session ID is stored in browser localStorage under the key `rag_session_id`
- Clear browser localStorage to reset the session
- Ensure the ChromaDB persistence directory is writable
- Check the session timeout settings
- Verify session cleanup is running
- If a session ID exists in localStorage but its data is missing, a new session is created automatically

### Japanese Support Issues
- Install `janome`: `pip install janome`
- Ensure documents are correctly detected as Japanese
- Check that a Japanese tag generation method is selected

## Acknowledgments

Built for comprehensive RAG evaluation and comparison, supporting multiple retrieval strategies, tag-based filtering, hybrid approaches, and extensive evaluation metrics.

SETUP_GUIDE.md
ADDED
@@ -0,0 +1,512 @@
# Setup Guide - Auto Tagging RAG System

## Overview

This guide provides detailed setup instructions for the Auto Tagging RAG System, including environment configuration, dependency installation, deployment options, and troubleshooting.

## Table of Contents

1. [Prerequisites](#prerequisites)
2. [Installation](#installation)
3. [Environment Configuration](#environment-configuration)
4. [Deployment Options](#deployment-options)
5. [Initial Setup](#initial-setup)
6. [Verification](#verification)
7. [Troubleshooting](#troubleshooting)

## Prerequisites

### System Requirements

- **Python**: 3.8 or higher
- **Operating System**: Linux, macOS, or Windows
- **Memory**: Minimum 4GB RAM (8GB+ recommended)
- **Storage**: At least 2GB free space for models and data
- **Network**: Internet connection for initial model downloads (offline mode supported after setup)

### Software Dependencies

- Python package manager: `pip` or `conda`
- Git (for cloning the repository)

## Installation

### Step 1: Clone Repository

```bash
git clone <repository-url>
cd auto_tagging_rag
```

### Step 2: Create Virtual Environment (Recommended)

**Using venv:**
```bash
python3 -m venv venv
source venv/bin/activate  # On Linux/macOS
# or
venv\Scripts\activate     # On Windows
```

**Using conda:**
```bash
conda create -n auto_tagging_rag python=3.10
conda activate auto_tagging_rag
```

### Step 3: Install Dependencies

```bash
pip install -r requirements.txt
```

**Important**: Wait for all packages to install completely before proceeding to Step 4.

### Step 4: Create Required Directories

```bash
mkdir -p reports chroma_data
```

These directories will store:
- `reports/`: Evaluation results, visualizations, and generated reports
- `chroma_data/`: Vector database persistent storage

### Step 5: Model Downloads (Automatic)

The first time you run the application, it will automatically download:
- **SentenceTransformers**: `all-MiniLM-L6-v2` embedding model (~80MB)
- **spaCy**: `en_core_web_sm` model (if using spaCy tag generation)

**Important**:
- Models download automatically on first use
- No manual download required
- Ensure an internet connection for initial downloads
- Models are cached locally for offline operation

## Environment Configuration

### Creating `.env` File

Create a `.env` file in the project root:

```bash
touch .env  # Linux/macOS
# or create the .env file manually on Windows
```

### Environment Variables

Copy the template below and configure as needed:

```bash
# ============================================
# OpenAI Configuration (Optional)
# ============================================
# Required only if using OpenAI embeddings/tag generation
OPENAI_API_KEY=your_openai_api_key_here

# OpenAI Model Selection
OPENAI_MODEL=gpt-4o-mini                    # Chat model for tag generation
OPENAI_EMBED_MODEL=text-embedding-3-small   # Embedding model (1536 dimensions)

# Force OpenAI embeddings (true/false)
USE_OPENAI_EMBEDDINGS=false                 # Set to true to force OpenAI embeddings


# ============================================
# Embedding Configuration
# ============================================
# SentenceTransformers fallback model
ST_EMBED_MODEL=all-MiniLM-L6-v2             # Used when OpenAI is not available


# ============================================
# Reranking Configuration
# ============================================
RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2
RERANKER_STRATEGY=cross_encoder             # Options: cross_encoder, openai, semantic, heuristic


# ============================================
# Tag Generation Configuration
# ============================================
TAG_GENERATION_METHOD=auto                  # Options: yake, keybert, spacy, janome, openai, auto
MAX_TAGS_PER_CHUNK=10                       # Maximum tags per document chunk
MIN_TAG_LENGTH=2                            # Minimum tag length (words)
MAX_TAG_LENGTH=5                            # Maximum tag length (words)


# ============================================
# Database & Storage Configuration
# ============================================
CHROMA_PERSIST_DIR=./chroma_data            # ChromaDB persistence directory
SESSION_TIMEOUT=3600                        # Session timeout in seconds (default: 1 hour)


# ============================================
# UI & Behavior Configuration
# ============================================
DEFAULT_SEARCH_K=5                          # Default number of search results
GRADIO_SERVER_PORT=7860                     # Gradio server port
GRADIO_SERVER_NAME=0.0.0.0                  # Server host (0.0.0.0 for all interfaces)
LOG_LEVEL=INFO                              # Logging level: DEBUG, INFO, WARNING, ERROR
```

### Environment Variable Descriptions

#### OpenAI Configuration
- **OPENAI_API_KEY**: Your OpenAI API key (get one from https://platform.openai.com/api-keys)
  - Required only if using OpenAI embeddings or tag generation
  - Can be omitted for fully offline operation
- **OPENAI_MODEL**: Chat model for tag generation (default: `gpt-4o-mini`)
- **OPENAI_EMBED_MODEL**: Embedding model (default: `text-embedding-3-small`, 1536 dimensions)
- **USE_OPENAI_EMBEDDINGS**: Force OpenAI embeddings even if an API key exists (default: `false`)

#### Embedding Configuration
- **ST_EMBED_MODEL**: SentenceTransformers model (default: `all-MiniLM-L6-v2`)
  - Automatically downloaded on first run
  - Used when OpenAI is not available or disabled

#### Reranking Configuration
- **RERANKER_MODEL**: Cross-encoder model for reranking (default: `cross-encoder/ms-marco-MiniLM-L-6-v2`)
  - Automatically downloaded on first use
- **RERANKER_STRATEGY**: Reranking method (default: `cross_encoder`)
  - Options: `cross_encoder`, `openai`, `semantic`, `heuristic`

#### Tag Generation Configuration
- **TAG_GENERATION_METHOD**: Tag generation method (default: `auto`)
  - `auto`: Automatically selects the best available method
  - `yake`: YAKE keyword extraction (English/Japanese)
  - `keybert`: KeyBERT with BERT embeddings
  - `spacy`: spaCy noun phrase extraction (English only)
  - `janome`: Japanese tokenization and noun extraction
  - `openai`: OpenAI-based tag generation (requires API key)
- **MAX_TAGS_PER_CHUNK**: Maximum tags generated per chunk (default: `10`)
- **MIN_TAG_LENGTH**: Minimum tag length in words (default: `2`)
- **MAX_TAG_LENGTH**: Maximum tag length in words (default: `5`)

#### Database & Storage
- **CHROMA_PERSIST_DIR**: Directory for ChromaDB storage (default: `./chroma_data`)
- **SESSION_TIMEOUT**: Session expiration time in seconds (default: `3600` = 1 hour)

#### UI & Behavior
- **DEFAULT_SEARCH_K**: Default number of search results (default: `5`)
- **GRADIO_SERVER_PORT**: Server port (default: `7860`)
- **GRADIO_SERVER_NAME**: Server hostname (default: `0.0.0.0` for all interfaces)
- **LOG_LEVEL**: Logging verbosity (default: `INFO`)

### Minimal Configuration Example

For offline operation without OpenAI:

```bash
# Minimal .env file for offline operation
LOG_LEVEL=INFO
```

The system will use default values for all other settings.
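
Defaults are applied with the standard `os.getenv` fallback pattern; the sketch below shows how a few of the variables above resolve (the variable names match the table, but the app's actual parsing may differ in detail):

```python
# Minimal sketch of the os.getenv fallback pattern used for configuration;
# the app's actual parsing code may differ in detail.
import os

SESSION_TIMEOUT = int(os.getenv("SESSION_TIMEOUT", "3600"))  # seconds
DEFAULT_SEARCH_K = int(os.getenv("DEFAULT_SEARCH_K", "5"))
CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "./chroma_data")
USE_OPENAI_EMBEDDINGS = os.getenv("USE_OPENAI_EMBEDDINGS", "false").lower() == "true"

print(SESSION_TIMEOUT, DEFAULT_SEARCH_K, CHROMA_PERSIST_DIR, USE_OPENAI_EMBEDDINGS)
```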

## Deployment Options

### Option 1: Local Development

**Run locally:**
```bash
python app.py
```

Access at: `http://localhost:7860`

### Option 2: Hugging Face Spaces

**Requirements**:
- Hugging Face account
- Space created (set SDK to `gradio`)

**Deployment Steps**:

1. **Push to Hugging Face**:
   ```bash
   git remote add huggingface https://huggingface.co/spaces/<your-username>/<space-name>
   git push huggingface main
   ```

2. **Configure Space Settings**:
   - **SDK**: `gradio`
   - **Python Version**: `3.10`
   - **Hardware**: CPU (or GPU if needed)

3. **Add Secrets** (if using OpenAI):
   - Go to Space Settings → Secrets
   - Add the `OPENAI_API_KEY` secret

4. **Environment Variables**:
   - Add via the Space Settings → Variables tab
   - Or use secrets for sensitive values

**Note**: Hugging Face Spaces automatically installs from `requirements.txt` and runs `app.py`.

### Option 3: Docker Deployment

**Create Dockerfile**:
```dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Models download automatically on first run

# Copy application
COPY . .

# Create directories
RUN mkdir -p reports chroma_data

# Expose port
EXPOSE 7860

# Run application
CMD ["python", "app.py"]
```

**Build and Run**:
```bash
docker build -t auto-tagging-rag .
docker run -p 7860:7860 --env-file .env auto-tagging-rag
```

### Option 4: Cloud Deployment (AWS/GCP/Azure)

**For AWS EC2**:
```bash
# Install dependencies
sudo apt-get update
sudo apt-get install python3-pip python3-venv git

# Clone repository
git clone <repository-url>
cd auto_tagging_rag

# Set up virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements.txt
# Models download automatically on first run

# Configure environment
cp .env.example .env
# Edit .env with your settings

# Run with nohup or systemd
nohup python app.py > app.log 2>&1 &
```

**Security Groups**: Open port 7860 (or your configured port) for HTTP access.

## Initial Setup

### First Run Checklist

1. ✅ Python 3.8+ installed
2. ✅ Dependencies installed (`pip install -r requirements.txt`)
3. ✅ `.env` file created and configured (optional, only if using OpenAI)
4. ✅ Directories created (`reports/`, `chroma_data/`)
5. ✅ Internet connection available (models download automatically on first run)

### First Launch

```bash
python app.py
```

**Expected Output**:
```
INFO:rag_app:OpenAI API key detected. OpenAI-powered auto-detection is ENABLED.
# or
INFO:rag_app:OpenAI API key not found. Using SentenceTransformers embeddings.
INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled...
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
Running on local URL: http://127.0.0.1:7860
```

**First-time downloads**:
- SentenceTransformers model (~80MB) - downloads automatically
- Reranker model (~40MB) - downloads on first use if reranking is enabled

### Offline Operation

After initial setup, the system works **fully offline**:
- All models are cached locally
- No network access required
- The tiktoken fallback handles network issues gracefully

## Verification

### Test Installation

1. **Start Application**:
   ```bash
   python app.py
   ```

2. **Access UI**: Open `http://localhost:7860` in a browser

3. **Test Upload**:
   - Go to the "Upload & Tagging" tab
   - Upload `sample_documents/emergency_procedures.txt`
   - Click "Build RAG Index"
   - Verify: tags generated, document count = 1

4. **Test Search**:
   - Go to the "Search & Compare" tab
   - Enter the query: "What are emergency procedures?"
   - Click "Search All Methods"
   - Verify: results appear for all 4 methods

5. **Test Evaluation**:
   - Go to the "Analytics & Evaluation" tab
   - Copy the content from `sample_evaluation_queries.json`
   - Paste it into the "Evaluation Queries" field
   - Click "Run Evaluation"
   - Verify: charts appear, CSV generated

### Verify Models

**Check spaCy model**:
```bash
python3 -c "import spacy; nlp = spacy.load('en_core_web_sm'); print('✓ spaCy model loaded')"
```

**Check SentenceTransformers**:
```bash
python3 -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2'); print('✓ SentenceTransformers model loaded')"
```

**Check ChromaDB**:
```bash
python3 -c "import chromadb; client = chromadb.PersistentClient(path='./chroma_data'); print('✓ ChromaDB initialized')"
```
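
If you prefer a single command, the three checks can be wrapped in one small script; this is just the one-liners above combined, so a failure in one check is reported without aborting the others:

```python
# verify_setup.py - runs the three model checks above in sequence
def load_spacy():
    import spacy
    spacy.load("en_core_web_sm")

def load_st():
    from sentence_transformers import SentenceTransformer
    SentenceTransformer("all-MiniLM-L6-v2")

def load_chroma():
    import chromadb
    chromadb.PersistentClient(path="./chroma_data")

for name, fn in [("spaCy en_core_web_sm", load_spacy),
                 ("SentenceTransformers all-MiniLM-L6-v2", load_st),
                 ("ChromaDB ./chroma_data", load_chroma)]:
    try:
        fn()
        print(f"[OK] {name}")
    except Exception as e:
        print(f"[FAIL] {name}: {e}")
```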

## Troubleshooting

### Common Issues

#### Issue 1: "No module named spacy"

**Solution**:
```bash
pip install "spacy>=3.7.0"
python -m spacy download en_core_web_sm
```

#### Issue 2: "spaCy model not found"

**Solution**:
```bash
python -m spacy download en_core_web_sm
# Verify
python3 -c "import spacy; nlp = spacy.load('en_core_web_sm')"
```

#### Issue 3: "Permission denied: /data"

**Solution**: ChromaDB defaults to `/data`, which may not be writable. The app automatically falls back to `./chroma_data` or a user cache directory.

**Manual override**:
```bash
export CHROMA_PERSIST_DIR=./chroma_data
python app.py
```
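
For reference, the fallback in `initialize_system()` (in `app.py`, included in this commit) condenses to a write probe on the preferred directory:

```python
# Condensed from initialize_system() in app.py: probe the preferred
# directory with a throwaway write; fall back to ./chroma_data on failure.
import os

def pick_persist_dir() -> str:
    persist_dir = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
    try:
        os.makedirs(persist_dir, exist_ok=True, mode=0o755)
        probe = os.path.join(persist_dir, ".test_write")
        with open(probe, "w") as f:
            f.write("test")
        os.remove(probe)
    except (PermissionError, OSError):
        persist_dir = "./chroma_data"
        os.makedirs(persist_dir, exist_ok=True, mode=0o755)
    return persist_dir
```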

#### Issue 4: "Tiktoken network warning"

**Explanation**: This is expected in offline mode. Tiktoken attempts to download encodings but falls back gracefully to character-based token counting.

**Solution**: No action needed - the system handles this automatically.
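
A character-based fallback of the kind described can be as simple as the sketch below; this is illustrative rather than the project's exact implementation, and the 4-characters-per-token ratio is only a common English rule of thumb:

```python
# Illustrative sketch of a tiktoken fallback - not the project's actual code.
def count_tokens(text: str) -> int:
    try:
        import tiktoken  # may trigger a network fetch for encodings
        enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text))
    except Exception:
        # Offline fallback: approximate ~4 characters per token
        return max(1, len(text) // 4)
```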

#### Issue 5: "OpenAI API key not working"

**Solution**:
1. Verify the key in your `.env` file: `OPENAI_API_KEY=sk-...`
2. Check that the key is valid: https://platform.openai.com/api-keys
3. Restart the application after changing `.env`

#### Issue 6: "ChromaDB collection not found after restart"

**Solution**: Collections are namespaced by embedding provider/dimension. If you switch providers, you need to re-upload documents.

**Check collections**:
```python
import chromadb
client = chromadb.PersistentClient(path="./chroma_data")
collections = client.list_collections()
print([col.name for col in collections])
```

#### Issue 7: "Session not persisting after refresh"

**Solution**:
- Sessions are stored in browser `localStorage` automatically
- Server restarts: sessions are restored from ChromaDB collections if they exist
- Check the browser console for errors

#### Issue 8: "Import errors"

**Solution**:
```bash
# Reinstall all dependencies
pip install --upgrade -r requirements.txt
```

#### Issue 9: "Memory errors with large documents"

**Solution**:
- Reduce `MAX_TAGS_PER_CHUNK` in `.env`
- Process documents in smaller batches (see the sketch below)
- Increase system RAM or use smaller embedding models
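
A minimal batching sketch, assuming `upload_documents` from `app.py` is available and using an arbitrary batch size of 5 (`all_file_paths` is a placeholder for your own list of paths):

```python
# Hypothetical batching helper; upload_documents refers to the function in
# app.py, and all_file_paths stands in for your own list of document paths.
def batched(items, size=5):
    for i in range(0, len(items), size):
        yield items[i:i + size]

for batch in batched(all_file_paths, size=5):
    status, summaries, stats, rows = upload_documents(batch, language=None)
    print(status)
```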

#### Issue 10: "Port 7860 already in use"

**Solution**:
```bash
# Find the process using the port
lsof -i :7860                 # Linux/macOS
netstat -ano | findstr :7860  # Windows

# Kill the process, or change the port in .env:
GRADIO_SERVER_PORT=7861
```

### Getting Help

1. **Check Logs**: Look at console output for error messages
2. **Enable Debug Logging**: Set `LOG_LEVEL=DEBUG` in `.env`
3. **Check Documentation**: See `README.md` and the other `.md` files
4. **Verify Installation**: Run the verification steps above

## Next Steps

After successful setup:

1. **Upload Sample Documents**: Use the files in `sample_documents/`
2. **Run Sample Evaluation**: Use `sample_evaluation_queries.json`
3. **Explore Features**: Try all tabs in the UI
4. **Read Guides**:
   - `EVALUATION_GUIDE.md` - How to run evaluation
   - `SEARCH_COMPARE_GUIDE.md` - How to use Search & Compare
   - `TAG_GENERATION_GUIDE.md` - Tag generation details
app.py
ADDED
@@ -0,0 +1,1552 @@
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import tempfile
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 8 |
+
import shutil
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import zipfile
|
| 12 |
+
from core.ingest import FlatTagChunker
|
| 13 |
+
from core.index import VectorStore
|
| 14 |
+
from core.retrieval import RAGManager
|
| 15 |
+
from core.eval import RAGEvaluator
|
| 16 |
+
from core.utils import generate_id
|
| 17 |
+
from core.comparison import RAGComparisonFramework
|
| 18 |
+
from core.visualization import RAGVisualizer
|
| 19 |
+
from core.report_generator import ReportGenerator
|
| 20 |
+
from core.session_manager import SessionManager
|
| 21 |
+
from core.session_rag import SessionAwareRAGManager
|
| 22 |
+
import os as _os
|
| 23 |
+
_OPENAI_ON = False
|
| 24 |
+
try:
|
| 25 |
+
from openai import OpenAI as _OpenAI
|
| 26 |
+
_OPENAI_ON = True if _os.getenv("OPENAI_API_KEY") else False
|
| 27 |
+
except Exception:
|
| 28 |
+
_OPENAI_ON = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Load environment variables from .env if present, then configure logging
|
| 32 |
+
load_dotenv()
|
| 33 |
+
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
| 34 |
+
logger = logging.getLogger("rag_app")
|
| 35 |
+
if os.getenv("OPENAI_API_KEY"):
|
| 36 |
+
logger.info("OpenAI API key detected. OpenAI-powered auto-detection is ENABLED.")
|
| 37 |
+
if os.getenv("OPENAI_MODEL"):
|
| 38 |
+
logger.info(f"OpenAI model: {os.getenv('OPENAI_MODEL')}")
|
| 39 |
+
else:
|
| 40 |
+
logger.info("OpenAI API key not set. Falling back to heuristic auto-detection.")
|
| 41 |
+
|
| 42 |
+
# Global variables
|
| 43 |
+
rag_manager = None
|
| 44 |
+
evaluator = None
|
| 45 |
+
comparator = None
|
| 46 |
+
visualizer = None
|
| 47 |
+
report_generator = None
|
| 48 |
+
session_manager = None
|
| 49 |
+
session_rag_manager = None
|
| 50 |
+
current_collection = "documents"
|
| 51 |
+
persist_directory = None
|
| 52 |
+
|
| 53 |
+
def initialize_system():
|
| 54 |
+
"""Initialize the RAG system"""
|
| 55 |
+
global rag_manager, evaluator, comparator, visualizer, report_generator, session_manager, session_rag_manager, persist_directory
|
| 56 |
+
|
| 57 |
+
# Try /data/chroma first (for HF Spaces), fallback to ./chroma_data
|
| 58 |
+
persist_dir = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
|
| 59 |
+
|
| 60 |
+
# Create directory with proper permissions, and check if we can write to it
|
| 61 |
+
try:
|
| 62 |
+
os.makedirs(persist_dir, exist_ok=True, mode=0o755)
|
| 63 |
+
# Test write permissions
|
| 64 |
+
test_file = os.path.join(persist_dir, ".test_write")
|
| 65 |
+
try:
|
| 66 |
+
with open(test_file, 'w') as f:
|
| 67 |
+
f.write("test")
|
| 68 |
+
os.remove(test_file)
|
| 69 |
+
except (PermissionError, OSError):
|
| 70 |
+
# If can't write to /data/chroma, use ./chroma_data
|
| 71 |
+
persist_dir = "./chroma_data"
|
| 72 |
+
os.makedirs(persist_dir, exist_ok=True, mode=0o755)
|
| 73 |
+
except (PermissionError, OSError) as e:
|
| 74 |
+
# If even ./chroma_data fails, try current directory
|
| 75 |
+
persist_dir = "./chroma_data"
|
| 76 |
+
os.makedirs(persist_dir, exist_ok=True, mode=0o755)
|
| 77 |
+
|
| 78 |
+
persist_directory = persist_dir
|
| 79 |
+
|
| 80 |
+
rag_manager = RAGManager(persist_directory=persist_dir)
|
| 81 |
+
evaluator = RAGEvaluator(rag_manager)
|
| 82 |
+
comparator = RAGComparisonFramework(evaluator)
|
| 83 |
+
visualizer = RAGVisualizer()
|
| 84 |
+
report_generator = ReportGenerator()
|
| 85 |
+
|
| 86 |
+
# Initialize session manager
|
| 87 |
+
session_timeout = int(os.getenv("SESSION_TIMEOUT", 3600))
|
| 88 |
+
session_manager = SessionManager(base_persist_dir=persist_dir, session_timeout=session_timeout)
|
| 89 |
+
session_rag_manager = SessionAwareRAGManager(rag_manager, session_manager)
|
| 90 |
+
|
| 91 |
+
return f"System initialized successfully! Using persist directory: {persist_dir}"
|
| 92 |
+
|
| 93 |
+
def reset_index() -> str:
|
| 94 |
+
"""Clear Chroma persistence and reinitialize the vector store."""
|
| 95 |
+
global rag_manager, evaluator, comparator, visualizer, report_generator, session_manager, session_rag_manager, persist_directory
|
| 96 |
+
try:
|
| 97 |
+
dir_path = persist_directory or ("/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data")
|
| 98 |
+
if os.path.exists(dir_path):
|
| 99 |
+
shutil.rmtree(dir_path, ignore_errors=True)
|
| 100 |
+
os.makedirs(dir_path, exist_ok=True, mode=0o755)
|
| 101 |
+
persist_directory = dir_path
|
| 102 |
+
rag_manager = RAGManager(persist_directory=dir_path)
|
| 103 |
+
evaluator = RAGEvaluator(rag_manager)
|
| 104 |
+
comparator = RAGComparisonFramework(evaluator)
|
| 105 |
+
visualizer = RAGVisualizer()
|
| 106 |
+
report_generator = ReportGenerator()
|
| 107 |
+
session_timeout = int(os.getenv("SESSION_TIMEOUT", 3600))
|
| 108 |
+
session_manager = SessionManager(base_persist_dir=dir_path, session_timeout=session_timeout)
|
| 109 |
+
session_rag_manager = SessionAwareRAGManager(rag_manager, session_manager)
|
| 110 |
+
return f"Index reset complete. Using fresh directory: {dir_path}"
|
| 111 |
+
except Exception as ex:
|
| 112 |
+
return f"Failed to reset index: {ex}"
|
| 113 |
+
|
| 114 |
+
def upload_documents(files: List[str], language: str, user_tags: Optional[str] = None, use_flat_tags: bool = True, collection_name: str = None, progress: Any = None) -> Tuple[str, List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]]:
|
| 115 |
+
"""Upload and process documents.
|
| 116 |
+
Returns: (status_text, per_file_summaries, collection_stats, chunk_rows)
|
| 117 |
+
per_file_summaries: [{filename, chunks, language, tags}]
|
| 118 |
+
"""
|
| 119 |
+
global rag_manager, persist_directory, current_collection
|
| 120 |
+
|
| 121 |
+
if not files:
|
| 122 |
+
return "No files provided!", [], {}, []
|
| 123 |
+
|
| 124 |
+
# Ensure system is initialized
|
| 125 |
+
if not rag_manager:
|
| 126 |
+
initialize_system()
|
| 127 |
+
|
| 128 |
+
# Use provided collection or default
|
| 129 |
+
if collection_name:
|
| 130 |
+
current_collection = collection_name
|
| 131 |
+
|
| 132 |
+
# Parse user tags from comma-separated string
|
| 133 |
+
user_tags_list = []
|
| 134 |
+
if user_tags and user_tags.strip():
|
| 135 |
+
user_tags_list = [tag.strip() for tag in user_tags.split(',') if tag.strip()]
|
| 136 |
+
|
| 137 |
+
# Choose chunker based on tag mode
|
| 138 |
+
# Use FlatTagChunker for flat tagging
|
| 139 |
+
chunker = FlatTagChunker()
|
| 140 |
+
|
| 141 |
+
all_chunks = []
|
| 142 |
+
per_file_summaries: List[Dict[str, Any]] = []
|
| 143 |
+
chunk_rows: List[Dict[str, Any]] = []
|
| 144 |
+
|
| 145 |
+
processed_count = 0
|
| 146 |
+
errors: List[str] = []
|
| 147 |
+
total = len(files)
|
| 148 |
+
|
| 149 |
+
for idx, file_path in enumerate(files, start=1):
|
| 150 |
+
if progress:
|
| 151 |
+
try:
|
| 152 |
+
progress(idx, total=total, desc=f"Processing {idx}/{total}: {os.path.basename(file_path)}")
|
| 153 |
+
except Exception:
|
| 154 |
+
pass
|
| 155 |
+
try:
|
| 156 |
+
chunks = chunker.chunk_document(file_path, language=language, user_tags=user_tags_list)
|
| 157 |
+
if chunks:
|
| 158 |
+
# Aggregate per-file metadata from chunks (majority vote)
|
| 159 |
+
from collections import Counter
|
| 160 |
+
langs = Counter([c.metadata.get('lang') for c in chunks if c.metadata.get('lang')])
|
| 161 |
+
|
| 162 |
+
# Count tags (including user tags)
|
| 163 |
+
tag_count = sum(len(c.metadata.get('tags', [])) for c in chunks)
|
| 164 |
+
|
| 165 |
+
# Count user tags vs auto tags (first tags are user tags)
|
| 166 |
+
user_tag_count = len(user_tags_list) if user_tags_list else 0
|
| 167 |
+
|
| 168 |
+
per_file_summaries.append({
|
| 169 |
+
'Filename': os.path.basename(file_path),
|
| 170 |
+
'Chunks': len(chunks),
|
| 171 |
+
'Language': (langs.most_common(1)[0][0] if langs else None),
|
| 172 |
+
'Tags': tag_count,
|
| 173 |
+
'User Tags': user_tag_count
|
| 174 |
+
})
|
| 175 |
+
# Prepare per-chunk preview rows
|
| 176 |
+
for c in chunks:
|
| 177 |
+
md = c.metadata or {}
|
| 178 |
+
row = {
|
| 179 |
+
'Filename': os.path.basename(file_path),
|
| 180 |
+
'Language': md.get('lang'),
|
| 181 |
+
'Tags': ', '.join(md.get('tags', [])[:5]) # Show first 5 tags
|
| 182 |
+
}
|
| 183 |
+
row['Preview'] = (c.content[:160] + '...') if c.content else ''
|
| 184 |
+
chunk_rows.append(row)
|
| 185 |
+
|
| 186 |
+
all_chunks.extend(chunks)
|
| 187 |
+
processed_count += 1
|
| 188 |
+
else:
|
| 189 |
+
errors.append(f"Warning: {os.path.basename(file_path)} produced no chunks")
|
| 190 |
+
except Exception as e:
|
| 191 |
+
error_msg = f"{os.path.basename(file_path)}: {str(e)}"
|
| 192 |
+
errors.append(error_msg)
|
| 193 |
+
|
| 194 |
+
# Index if any chunk present
|
| 195 |
+
vector_store = rag_manager.vector_store
|
| 196 |
+
stats: Dict[str, Any] = {"document_count": 0, "collection_name": current_collection}
|
| 197 |
+
if all_chunks:
|
| 198 |
+
vector_store.add_documents(current_collection, all_chunks)
|
| 199 |
+
stats = vector_store.get_collection_stats(current_collection)
|
| 200 |
+
|
| 201 |
+
# Build result message
|
| 202 |
+
total_tags = sum(len(c.metadata.get('tags', [])) for c in all_chunks)
|
| 203 |
+
status_lines = [
|
| 204 |
+
f"Processed {processed_count}/{total} files",
|
| 205 |
+
f"Indexed chunks: {len(all_chunks)}",
|
| 206 |
+
f"Generated tags: {total_tags}"
|
| 207 |
+
]
|
| 208 |
+
if errors:
|
| 209 |
+
status_lines.append("\nErrors/Warnings:\n" + "\n".join(f"- {e}" for e in errors))
|
| 210 |
+
|
| 211 |
+
return "\n".join(status_lines), per_file_summaries, stats, chunk_rows
|
| 212 |
+
|
| 213 |
+
def init_session(client_session_id: Optional[str] = None) -> Dict[str, Any]:
|
| 214 |
+
"""
|
| 215 |
+
Initialize or restore a user session.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
client_session_id: Optional session ID from client (localStorage)
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Dict with session_id and collection_name
|
| 222 |
+
"""
|
| 223 |
+
global session_manager
|
| 224 |
+
if not session_manager:
|
| 225 |
+
initialize_system()
|
| 226 |
+
|
| 227 |
+
# Use get_or_create_session which checks ChromaDB for existing collections
|
| 228 |
+
# This handles server restarts - if ChromaDB collection exists, restore session
|
| 229 |
+
session = session_manager.get_or_create_session(session_id=client_session_id)
|
| 230 |
+
|
| 231 |
+
if client_session_id and session.session_id == client_session_id:
|
| 232 |
+
# Successfully restored existing session
|
| 233 |
+
logger.info(f"Restored session {client_session_id} (collection: {session.collection_name})")
|
| 234 |
+
return {"session_id": session.session_id, "collection_name": session.collection_name}
|
| 235 |
+
else:
|
| 236 |
+
# Created new session
|
| 237 |
+
logger.info(f"Created new session {session.session_id} (collection: {session.collection_name})")
|
| 238 |
+
return {"session_id": session.session_id, "collection_name": session.collection_name, "new_session": True}
|
| 239 |
+
|
| 240 |
+
def get_document_count(session_state: Dict[str, Any]) -> str:
|
| 241 |
+
"""Get total document count from the current session's collection - returns only the number"""
|
| 242 |
+
global rag_manager
|
| 243 |
+
if not session_state or not session_state.get("collection_name"):
|
| 244 |
+
return "0"
|
| 245 |
+
|
| 246 |
+
try:
|
| 247 |
+
if not rag_manager:
|
| 248 |
+
initialize_system()
|
| 249 |
+
|
| 250 |
+
collection_name = session_state.get("collection_name")
|
| 251 |
+
stats = rag_manager.vector_store.get_collection_stats(collection_name)
|
| 252 |
+
doc_count = stats.get("document_count", 0)
|
| 253 |
+
return str(doc_count)
|
| 254 |
+
except Exception as e:
|
| 255 |
+
logger.warning(f"Failed to get document count: {e}")
|
| 256 |
+
return "0"
|
| 257 |
+
|
| 258 |
+
def build_with_session(files: List[Any], language: str, user_tags: str, session_state: Dict[str, Any], progress=gr.Progress()) -> Tuple[Dict[str, Any], str, pd.DataFrame, pd.DataFrame, str]:
|
| 259 |
+
"""Build RAG index with session support and progress tracking"""
|
| 260 |
+
# Initialize or get session
|
| 261 |
+
if not session_state or not session_state.get("session_id"):
|
| 262 |
+
session_state = init_session()
|
| 263 |
+
|
| 264 |
+
# Refresh session to prevent expiration (get_session updates access time)
|
| 265 |
+
global session_manager
|
| 266 |
+
if session_manager and session_state.get("session_id"):
|
| 267 |
+
session_manager.get_session(session_state["session_id"])
|
| 268 |
+
|
| 269 |
+
# Build index using session collection with progress (always use flat tags)
|
| 270 |
+
status, stats_df, chunks_df = build_rag_index(
|
| 271 |
+
files, language, user_tags, use_flat_tags=True, collection_name=session_state.get("collection_name")
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Extract tags for visualization
|
| 275 |
+
tag_list_str = ""
|
| 276 |
+
if chunks_df is not None and not chunks_df.empty and 'Tags' in chunks_df.columns:
|
| 277 |
+
all_tags = []
|
| 278 |
+
for tags_str in chunks_df['Tags'].dropna():
|
| 279 |
+
if tags_str:
|
| 280 |
+
all_tags.extend([t.strip() for t in str(tags_str).split(',')])
|
| 281 |
+
from collections import Counter
|
| 282 |
+
tag_counts = Counter(all_tags)
|
| 283 |
+
top_tags = [f"{tag} ({count})" for tag, count in tag_counts.most_common(20)]
|
| 284 |
+
tag_list_str = "\n".join([f"- {tag}" for tag in top_tags]) if top_tags else "No tags generated"
|
| 285 |
+
|
| 286 |
+
# Document count will be updated separately after processing completes to avoid progress bar
|
| 287 |
+
|
| 288 |
+
return session_state, status, stats_df, chunks_df, tag_list_str
|
| 289 |
+
|
| 290 |
+
def build_rag_index(files: List[Any], language: str, user_tags: str = "", use_flat_tags: bool = True, collection_name: str = None) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
|
| 291 |
+
"""Build RAG index from uploaded files"""
|
| 292 |
+
if not files:
|
| 293 |
+
return "No files provided!", None, None
|
| 294 |
+
|
| 295 |
+
# Gradio file objects already contain the full path in .name property
|
| 296 |
+
# No need to prepend /tmp/ - just use the path directly
|
| 297 |
+
file_paths = []
|
| 298 |
+
for file in files:
|
| 299 |
+
# Get the file path - Gradio provides it as .name or as a string
|
| 300 |
+
if isinstance(file, str):
|
| 301 |
+
file_path = file
|
| 302 |
+
elif hasattr(file, 'name') and file.name:
|
| 303 |
+
# file.name already contains the full path (e.g., /tmp/gradio/.../filename.txt)
|
| 304 |
+
file_path = file.name
|
| 305 |
+
else:
|
| 306 |
+
# Fallback for edge cases
|
| 307 |
+
return f"Error: Unable to get file path from uploaded file", None
|
| 308 |
+
|
| 309 |
+
# Normalize the path to handle any double slashes
|
| 310 |
+
file_path = os.path.normpath(file_path)
|
| 311 |
+
|
| 312 |
+
# Ensure the file exists
|
| 313 |
+
if not os.path.exists(file_path):
|
| 314 |
+
return f"Error: File not found at {file_path}", None
|
| 315 |
+
|
| 316 |
+
file_paths.append(file_path)
|
| 317 |
+
|
| 318 |
+
# Normalize "Auto" to None for auto-detection downstream
|
| 319 |
+
norm_language = None if not language or str(language).lower() == 'auto' else language
|
| 320 |
+
|
| 321 |
+
# Process documents (progress provided by gradio)
|
| 322 |
+
status_text, file_summaries, stats, chunk_rows = upload_documents(
|
| 323 |
+
file_paths, norm_language, user_tags,
|
| 324 |
+
use_flat_tags=use_flat_tags, collection_name=collection_name, progress=None
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Build per-file dataframe (no totals row)
|
| 328 |
+
per_file_df = pd.DataFrame(file_summaries) if file_summaries else pd.DataFrame(columns=['Filename','Chunks','Language','Tags','User Tags'])
|
| 329 |
+
chunks_df = pd.DataFrame(chunk_rows) if chunk_rows else pd.DataFrame(columns=['Filename','Language','Tags','Preview'])
|
| 330 |
+
return status_text, per_file_df, chunks_df
|
| 331 |
+
|
| 332 |
+
def _format_result(result, method_name, query: str, k: int, used_tags=None):
|
| 333 |
+
"""Format a single retrieval result for display"""
|
| 334 |
+
if not result or not result.sources:
|
| 335 |
+
return f"**{method_name}**: No results found", ""
|
| 336 |
+
|
| 337 |
+
# Generate answer
|
| 338 |
+
answer = _llm_answer(query, result.sources[:3]) # Use top 3 for answer generation
|
| 339 |
+
|
| 340 |
+
content_lines = [f"**{method_name}** (Latency: {result.latency:.3f}s)\n"]
|
| 341 |
+
content_lines.append(f"**Answer**: {answer}\n")
|
| 342 |
+
content_lines.append(f"\n**Retrieved Documents** ({len(result.sources)} total):\n")
|
| 343 |
+
|
| 344 |
+
# Get tags from retrieved documents
|
| 345 |
+
doc_tags = set()
|
| 346 |
+
for src in result.sources[:k]:
|
| 347 |
+
meta = src.get('metadata', {})
|
| 348 |
+
tags_from_meta = meta.get('tags', [])
|
| 349 |
+
# Handle both list and comma-separated string
|
| 350 |
+
if isinstance(tags_from_meta, str):
|
| 351 |
+
tags_from_meta = [t.strip() for t in tags_from_meta.split(',') if t.strip()]
|
| 352 |
+
doc_tags.update(tags_from_meta)
|
| 353 |
+
|
| 354 |
+
if used_tags:
|
| 355 |
+
content_lines.append(f"**Tags Used**: {', '.join(used_tags)}\n")
|
| 356 |
+
if doc_tags:
|
| 357 |
+
content_lines.append(f"**Tags in Results**: {', '.join(list(doc_tags)[:10])}\n")
|
| 358 |
+
|
| 359 |
+
content_lines.append("\n**Top Results**:\n")
|
| 360 |
+
for i, src in enumerate(result.sources[:k], 1):
|
| 361 |
+
score = src.get('score', 0)
|
| 362 |
+
meta = src.get('metadata', {})
|
| 363 |
+
source_name = meta.get('source_name', 'unknown')
|
| 364 |
+
src_tags = meta.get('tags', [])
|
| 365 |
+
if isinstance(src_tags, str):
|
| 366 |
+
src_tags = [t.strip() for t in src_tags.split(',') if t.strip()]
|
| 367 |
+
tag_str = f" [Tags: {', '.join(src_tags[:3])}]" if src_tags else ""
|
| 368 |
+
content_lines.append(f"{i}. [{source_name}] (Score: {score:.3f}){tag_str}")
|
| 369 |
+
content_lines.append(f" {src.get('content', '')[:120]}...\n")
|
| 370 |
+
|
| 371 |
+
return "\n".join(content_lines), answer
|
| 372 |
+
|
| 373 |
+
def search_all_methods(query: str, k: int, tags: str, tag_operator: str, vector_weight: float, tag_weight: float, session_state: Dict[str, Any]) -> Tuple[str, str, str, str, str, Dict[str, Any]]:
|
| 374 |
+
"""Search using all 4 retrieval methods - processes sequentially to reduce load"""
|
| 375 |
+
global session_rag_manager, rag_manager
|
| 376 |
+
|
| 377 |
+
if not session_rag_manager or not rag_manager:
|
| 378 |
+
initialize_system()
|
| 379 |
+
|
| 380 |
+
# Get or refresh session
|
| 381 |
+
if not session_state or not session_state.get("session_id"):
|
| 382 |
+
session_state = init_session()
|
| 383 |
+
else:
|
| 384 |
+
# Refresh session to prevent expiration (get_session updates access time)
|
| 385 |
+
session_manager.get_session(session_state["session_id"])
|
| 386 |
+
|
| 387 |
+
# Get session-aware RAG manager
|
| 388 |
+
rag = session_rag_manager.get_rag(session_state["session_id"])
|
| 389 |
+
|
| 390 |
+
# Parse tags
|
| 391 |
+
tag_list = [t.strip() for t in tags.split(',') if t.strip()] if tags else []
|
| 392 |
+
|
| 393 |
+
results = {}
|
| 394 |
+
base_text = "**Base RAG**: Processing..."
|
| 395 |
+
tag_text = "**Tag Filter RAG**: Waiting..."
|
| 396 |
+
hybrid_text = "**Hybrid RAG**: Waiting..."
|
| 397 |
+
rerank_text = "**Hybrid Rerank RAG**: Waiting..."
|
| 398 |
+
|
| 399 |
+
try:
|
| 400 |
+
# 1. Base RAG (process first)
|
| 401 |
+
logger.info("Processing Base RAG...")
|
| 402 |
+
base_result = rag.base_rag.retrieve(query, k)
|
| 403 |
+
results['base'] = base_result
|
| 404 |
+
base_text, base_answer = _format_result(base_result, "Base RAG", query, k)
|
| 405 |
+
|
| 406 |
+
# 2. Tag Filter RAG (process second)
|
| 407 |
+
logger.info("Processing Tag Filter RAG...")
|
| 408 |
+
tag_result = rag.tag_filter_rag.retrieve(query, k, tags=tag_list, tag_operator=tag_operator)
|
| 409 |
+
results['tag'] = tag_result
|
| 410 |
+
tag_text, tag_answer = _format_result(tag_result, f"Tag Filter RAG ({tag_operator})", query, k, tag_list)
|
| 411 |
+
|
| 412 |
+
# 3. Hybrid RAG (process third)
|
| 413 |
+
logger.info("Processing Hybrid RAG...")
|
| 414 |
+
hybrid_result = rag.hybrid_rag.retrieve(query, k, tags=tag_list, vector_weight=vector_weight, tag_weight=tag_weight)
|
| 415 |
+
results['hybrid'] = hybrid_result
|
| 416 |
+
hybrid_text, hybrid_answer = _format_result(hybrid_result, f"Hybrid RAG (V:{vector_weight:.1f}, T:{tag_weight:.1f})", query, k, tag_list)
|
| 417 |
+
|
| 418 |
+
# 4. Hybrid Rerank RAG (process last - most expensive)
|
| 419 |
+
logger.info("Processing Hybrid Rerank RAG...")
|
| 420 |
+
rerank_result = rag.hybrid_rerank_rag.retrieve(query, k, tags=tag_list, vector_weight=vector_weight, tag_weight=tag_weight)
|
| 421 |
+
results['rerank'] = rerank_result
|
| 422 |
+
rerank_text, rerank_answer = _format_result(rerank_result, "Hybrid Rerank RAG", query, k, tag_list)
|
| 423 |
+
|
| 424 |
+
except Exception as e:
|
| 425 |
+
error_msg = f"Error during retrieval: {str(e)}"
|
| 426 |
+
logger.error(f"Search error: {error_msg}", exc_info=True)
|
| 427 |
+
return error_msg, error_msg, error_msg, error_msg, f"Error: {error_msg}", session_state
|
| 428 |
+
|
| 429 |
+
# Create quick comparison metrics
|
| 430 |
+
metrics = []
|
| 431 |
+
for method_name, result, answer in [
|
| 432 |
+
("Base RAG", results.get('base'), base_answer if 'base_answer' in locals() else None),
|
| 433 |
+
("Tag Filter RAG", results.get('tag'), tag_answer if 'tag_answer' in locals() else None),
|
| 434 |
+
("Hybrid RAG", results.get('hybrid'), hybrid_answer if 'hybrid_answer' in locals() else None),
|
| 435 |
+
("Hybrid Rerank RAG", results.get('rerank'), rerank_answer if 'rerank_answer' in locals() else None)
|
| 436 |
+
]:
|
| 437 |
+
if result and result.sources:
|
| 438 |
+
metrics.append(f"- **{method_name}**: {len(result.sources)} docs, {result.latency:.3f}s, Score: {result.sources[0].get('score', 0):.3f}")
|
| 439 |
+
else:
|
| 440 |
+
metrics.append(f"- **{method_name}**: No results")
|
| 441 |
+
|
| 442 |
+
summary = f"""
|
| 443 |
+
### Query: {query}
|
| 444 |
+
### Quick Comparison:
|
| 445 |
+
{chr(10).join(metrics)}
|
| 446 |
+
### Tags: {', '.join(tag_list) if tag_list else 'None'}
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
logger.info("All search methods completed successfully")
|
| 450 |
+
return base_text, tag_text, hybrid_text, rerank_text, summary, session_state
|
| 451 |
+
|
| 452 |
+
def _llm_answer(user_message: str, contexts: List[Dict[str, Any]]) -> str:
    """Generate a natural, human-like answer grounded in contexts.

    Uses OpenAI if configured; otherwise produces a conversational fallback.
    """
    # Build a compact context string
    ctx_blocks = []
    for i, c in enumerate(contexts, 1):
        src = c.get('metadata', {}).get('source_name', 'unknown')
        snippet = (c.get('content', '') or '')[:400]
        ctx_blocks.append(f"[{i}] ({src}) {snippet}")
    ctx_text = "\n\n".join(ctx_blocks)

    # Prefer OpenAI
    if os.getenv("OPENAI_API_KEY") and '_OpenAI' in globals() and _OpenAI is not None:
        try:
            client = _OpenAI()
            system_prompt = (
                "You are a helpful, professional assistant. Answer in a warm, natural, and concise tone. "
                "ALWAYS ground the answer ONLY in the provided contexts. If information is missing, say so. "
                "Style: Start with a clear 1-2 sentence answer. Then, if helpful, add 2-5 short bullet points with key facts. "
                "Avoid hedging, avoid citations inline, avoid repeating the question."
            )
            content = (
                f"User question:\n{user_message}\n\n"
                f"Contexts (each begins with [n]):\n{ctx_text}\n\n"
                "Write a natural, human-like response as specified."
            )
            resp = client.chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ],
                temperature=0.2,
            )
            return resp.choices[0].message.content.strip()
        except Exception:
            pass

    # Fallback heuristic: conversational synthesis from top snippets
    if contexts:
        snippets = []
        for src in contexts[:3]:
            txt = (src.get('content') or '').strip()
            if txt:
                snippets.append(txt)
        joined = " ".join(snippets)
        # Naive sentence split
        sentences = [s.strip() for s in joined.replace('\n', ' ').split('.') if s.strip()]
        bullets = sentences[:4]
        lead = "Here’s what the documents say:"
        if bullets:
            bullets_text = "\n".join([f"- {b}." for b in bullets])
            return f"{lead}\n\n{bullets_text}"
    return "I don't have enough information to answer that yet."

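# Usage sketch for _llm_answer (hypothetical query): it expects the `sources`
# list produced by a pipeline's retrieve() call, where each item carries
# 'content' and 'metadata' keys, exactly as chat_with_rag does below:
#   result = rag.hybrid_rag.retrieve("What are the triage steps?", 3)
#   answer = _llm_answer("What are the triage steps?", result.sources)
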
def chat_with_rag(message: str, history: List[Dict[str, str]], pipeline: str, k: int, tags: str, use_tags: bool, session_state: Dict[str, Any]) -> Tuple[str, List[Dict[str, str]], Dict[str, Any], str]:
    """Chat interface with RAG: choose one pipeline, retrieve, then generate an LLM answer."""
    global session_rag_manager, rag_manager, session_manager

    if not session_rag_manager or not rag_manager:
        return "System not initialized! Please build the RAG index first.", history, session_state or {}, ""

    # Get or refresh session
    if not session_state or not session_state.get("session_id"):
        session_state = init_session()
    else:
        # Refresh session to prevent expiration (get_session updates access time)
        session_manager.get_session(session_state["session_id"])

    rag = session_rag_manager.get_rag(session_state["session_id"])

    # Parse tags - only use if the toggle is enabled
    tag_list = [t.strip() for t in tags.split(',') if t.strip()] if (use_tags and tags) else []

    # Retrieve with the chosen pipeline
    filters_note = ""
    result = None

    try:
        if pipeline == "base_rag":
            result = rag.base_rag.retrieve(message, k)
        elif pipeline == "tag_filter_rag":
            result = rag.tag_filter_rag.retrieve(message, k, tags=tag_list, tag_operator="OR")
            if tag_list:
                filters_note = f" (tags: {', '.join(tag_list)})"
        elif pipeline == "hybrid_rag":
            result = rag.hybrid_rag.retrieve(message, k, tags=tag_list, vector_weight=0.7, tag_weight=0.3)
            if tag_list:
                filters_note = f" (tags: {', '.join(tag_list)})"
        elif pipeline == "hybrid_rerank_rag":
            result = rag.hybrid_rerank_rag.retrieve(message, k, tags=tag_list, vector_weight=0.7, tag_weight=0.3)
            if tag_list:
                filters_note = f" (tags: {', '.join(tag_list)})"
        else:
            return f"Unknown pipeline: {pipeline}", history, session_state or {}, ""
    except Exception as e:
        return f"Error during retrieval: {str(e)}", history, session_state or {}, ""

    if not result:
        return "No results retrieved", history, session_state or {}, ""

    # Generate grounded answer
    answer = _llm_answer(message, result.sources)

    # Format sources with tags for display
    sources_display = []
    sources_display.append(f"### 📎 Source Documents ({len(result.sources)} total)\n")
    for i, s in enumerate(result.sources[:k], 1):
        meta = s.get('metadata', {})
        src_name = meta.get('source_name', 'unknown')
        src_tags = meta.get('tags', [])
        tag_str = f" **Tags:** {', '.join(src_tags[:5])}" if src_tags else ""
        score = s.get('score', 0)
        sources_display.append(f"{i}. **{src_name}** (Score: {score:.3f}){tag_str}")
        sources_display.append(f"   {s.get('content', '')[:150]}...\n")

    sources_text = "\n".join(sources_display)

    # Sources list (compact for chat)
    src_lines = []
    for s in result.sources[:k]:
        src_name = s.get('metadata', {}).get('source_name', 'unknown')
        src_tags = s.get('metadata', {}).get('tags', [])
        tag_str = f" [{', '.join(src_tags[:3])}]" if src_tags else ""
        src_lines.append(f"• {src_name}{tag_str}")

    pipeline_names = {
        "base_rag": "Base-RAG",
        "tag_filter_rag": "Tag Filter RAG",
        "hybrid_rag": "Hybrid RAG",
        "hybrid_rerank_rag": "Hybrid Rerank RAG"
    }
    pipeline_name = pipeline_names.get(pipeline, pipeline)

    response = (
        f"{answer}\n\n"
        f"─────────────────────────────────────\n"
        f"📎 Sources ({pipeline_name}{filters_note}):\n" + ("\n".join(src_lines) or "(none)")
    )

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return "", history, session_state, sources_text

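# Usage sketch (hypothetical, outside Gradio): chat_with_rag returns
# ("", updated_history, session_state, sources_markdown); the empty first
# element is what clears the chat input box in the UI wiring further below.
#   _, hist, state, srcs = chat_with_rag(
#       "What are the triage steps?", [], "hybrid_rag", 3, "", False, {})
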
def run_evaluation(queries_json: str, output_filename: str, user_satisfaction_json: str = None, session_state: Dict[str, Any] = None) -> Tuple[str, pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]]:
    """Run quantitative evaluation with all pipelines"""
    global evaluator, comparator, visualizer, report_generator, session_rag_manager, session_manager, rag_manager

    if not evaluator:
        return "System not initialized!", None, None, None

    # Get or refresh session
    if not session_state or not session_state.get("session_id"):
        session_state = init_session()
    else:
        # Refresh session to prevent expiration
        if session_manager and session_state.get("session_id"):
            session_manager.get_session(session_state["session_id"])

    # Use the session-aware RAG manager for evaluation
    if session_rag_manager and session_state and session_state.get("session_id"):
        rag = session_rag_manager.get_rag(session_state["session_id"])
        # Create a temporary evaluator with the session RAG manager
        from core.eval import RAGEvaluator
        evaluator = RAGEvaluator(rag)

    try:
        queries = json.loads(queries_json)
    except json.JSONDecodeError as e:
        return f"Invalid JSON format: {str(e)}", None, None, None

    # Parse user satisfaction scores if provided
    user_satisfaction = None
    if user_satisfaction_json:
        try:
            user_satisfaction = json.loads(user_satisfaction_json)
        except json.JSONDecodeError:
            pass

    # Run evaluation with all pipelines
    df, summary_dict, raw_results = evaluator.batch_evaluate(
        queries, output_filename,
        pipelines=['base_rag', 'tag_filter_rag', 'hybrid_rag', 'hybrid_rerank_rag'],
        user_satisfaction=user_satisfaction
    )

    # Generate summary statistics
    summary = df.groupby(['pipeline', 'k']).agg({
        'precision_at_k': 'mean',
        'ndcg_at_k': 'mean',
        'hit_at_k': 'mean',
        'mrr': 'mean',
        'semantic_similarity': 'mean',
        'latency': 'mean',
        'retrieved_count': 'mean'
    }).reset_index()

    if 'user_satisfaction' in df.columns:
        summary['user_satisfaction'] = df.groupby(['pipeline', 'k'])['user_satisfaction'].mean().reset_index()['user_satisfaction']

    # Create comparison plot data (for backward compatibility)
    plot_data = pd.DataFrame({
        'k': summary[summary['pipeline'] == 'base_rag']['k'],
        'base_rag_hit@k': summary[summary['pipeline'] == 'base_rag']['hit_at_k'],
        'tag_filter_rag_hit@k': summary[summary['pipeline'] == 'tag_filter_rag']['hit_at_k'],
        'hybrid_rag_hit@k': summary[summary['pipeline'] == 'hybrid_rag']['hit_at_k'],
        'hybrid_rerank_rag_hit@k': summary[summary['pipeline'] == 'hybrid_rerank_rag']['hit_at_k']
    })

    summary_dict = {
        'total_queries': len(queries),
        'pipelines': ['base_rag', 'tag_filter_rag', 'hybrid_rag', 'hybrid_rerank_rag'],
        'summary_stats': summary.to_dict('records'),
        'plot_data': plot_data.to_dict('records')
    }

    return f"Evaluation completed! Processed {len(queries)} queries. Results saved to {output_filename}", df, summary_dict, raw_results

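# Example payload for run_evaluation (mirrors the default JSON shown in the
# Analytics tab below; per-query "k_values" and "tags" are optional):
#   queries_json = json.dumps([{
#       "query": "What are the emergency procedures?",
#       "ground_truth": ["Emergency protocols for triage"],
#       "k_values": [1, 3, 5],
#   }])
#   status, df, summary, raw = run_evaluation(queries_json, "evaluation_results.csv")
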
def run_evaluation_with_viz(queries_json: str, output_filename: str, user_satisfaction_json: str = None, session_state: Dict[str, Any] = None) -> Tuple[str, pd.DataFrame, str, str, str, Dict[str, str], str]:
    """Run evaluation with visualization and report generation"""
    global visualizer, report_generator

    status, df, summary, raw_results = run_evaluation(queries_json, output_filename, user_satisfaction_json, session_state)

    if df is None:
        return status, None, "", None, None, {}, output_filename

    # Generate visualizations
    viz_files = {}
    if visualizer:
        try:
            viz_dir = "reports/visualizations"
            os.makedirs(viz_dir, exist_ok=True)
            viz_files = visualizer.create_all_charts(df, output_dir=viz_dir)
        except Exception as e:
            logger.warning(f"Visualization generation failed: {e}")

    # Generate report
    report_paths = {}
    if report_generator:
        try:
            report_paths = report_generator.generate_report(
                df, summary,
                report_name=output_filename.replace('.csv', ''),
                visualizations=viz_files,
                raw_results=raw_results
            )
        except Exception as e:
            logger.warning(f"Report generation failed: {e}")

    # Create summary text
    summary_text = "### Evaluation Summary\n\n"
    summary_text += f"**Total Queries**: {summary.get('total_queries', len(df['query'].unique())) if df is not None else 0}\n\n"

    if summary and 'summary_stats' in summary:
        summary_text += "**Average Metrics by Pipeline**:\n\n"
        # Group by pipeline to show all k values together
        pipelines = {}
        for stat in summary['summary_stats']:
            pipeline = stat.get('pipeline', 'unknown')
            if pipeline not in pipelines:
                pipelines[pipeline] = []
            pipelines[pipeline].append(stat)

        # Show all pipelines, sorted by name
        for pipeline in sorted(pipelines.keys()):
            for stat in sorted(pipelines[pipeline], key=lambda x: x.get('k', 0)):
                k = stat.get('k', 'N/A')
                summary_text += f"- {pipeline} (k={k}): "
                summary_text += f"Precision@{k}={stat.get('precision_at_k', 0):.3f}, "
                summary_text += f"nDCG@{k}={stat.get('ndcg_at_k', 0):.3f}, "
                summary_text += f"MRR={stat.get('mrr', 0):.3f}\n"

    if report_paths:
        summary_text += "\n**Reports Generated**:\n"
        for report_type, path in report_paths.items():
            summary_text += f"- {report_type}: {path}\n"

    # Get visualization files as images - ensure paths exist and are absolute
    bar_chart = None
    line_plot = None
    if viz_files:
        bar_path = viz_files.get('bar')
        line_path = viz_files.get('line')
        # Check if files exist and convert to absolute paths
        if bar_path and os.path.exists(bar_path):
            bar_chart = os.path.abspath(bar_path)
        if line_path and os.path.exists(line_path):
            line_plot = os.path.abspath(line_path)

    return status, df, summary_text, bar_chart, line_plot, viz_files, output_filename

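# The 7-tuple returned above maps one-to-one onto the Gradio outputs wired in
# eval_btn.click further below: [eval_output, eval_results_table,
# eval_summary_text, eval_bar_chart, eval_line_plot, eval_viz_state,
# eval_output_filename_state].
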
def update_remaining_vizs(viz_files: Dict[str, str]):
    """Update remaining visualization tabs"""
    if not viz_files:
        return None, None, None, None

    # Check if files exist and convert to absolute paths
    scatter = viz_files.get('scatter')
    box = viz_files.get('box')
    stacked_bar = viz_files.get('stacked_bar')
    pareto = viz_files.get('pareto')

    scatter_path = os.path.abspath(scatter) if scatter and os.path.exists(scatter) else None
    box_path = os.path.abspath(box) if box and os.path.exists(box) else None
    stacked_bar_path = os.path.abspath(stacked_bar) if stacked_bar and os.path.exists(stacked_bar) else None
    pareto_path = os.path.abspath(pareto) if pareto and os.path.exists(pareto) else None

    return scatter_path, box_path, stacked_bar_path, pareto_path

# Diagnostics: simple OpenAI connectivity test
## (removed) test_openai_connectivity helper

# Initialize system
initialize_system()

# Create Gradio interface
# Minimal CSS to keep the layout stable when the vertical scrollbar appears and improve mobile spacing
APP_CSS = """
html, body { scrollbar-gutter: stable both-edges; }
body { overflow-y: scroll; }
* { box-sizing: border-box; }
@media (max-width: 768px) {
  .gradio-container { padding-left: 8px; padding-right: 8px; }
}
"""

with gr.Blocks(title="Auto Tagging RAG", css=APP_CSS) as demo:
    # Global header with session and status indicators
    with gr.Row():
        gr.Markdown("# Auto Tagging RAG System")
    with gr.Row():
        session_indicator = gr.Markdown("**Session**: Not initialized", visible=True)
        document_count_indicator = gr.Markdown("**Documents**: 0", visible=True)

    # Session state
    session_state = gr.State(value={"session_id": None, "collection_name": None})

    # BrowserState to persist the session ID in localStorage
    browser_session_id = gr.BrowserState(default_value=None, storage_key="rag_session_id")

    with gr.Tab("Upload & Tagging"):
        gr.Markdown("## Upload and Process Documents")
        with gr.Row():
            with gr.Column():
                file_upload = gr.File(
                    label="Upload PDF/TXT Files",
                    file_count="multiple",
                    file_types=[".pdf", ".txt"]
                )
                language_dropdown = gr.Dropdown(
                    choices=["Auto", "en", "ja"],
                    label="Language",
                    value="Auto"
                )
                manual_tags_input = gr.Textbox(
                    label="Add Tags (comma-separated, optional)",
                    placeholder="hospital-protocol, urgent, confidential",
                    info="Add custom tags that will be combined with auto-generated tags",
                    lines=2
                )
                build_btn = gr.Button("Build RAG Index", variant="primary")

            with gr.Column():
                build_output = gr.Textbox(label="Build Status", lines=4)
                stats_table = gr.DataFrame(label="Processing Summary")
                chunks_table = gr.DataFrame(label="Indexed Chunks (preview)")
                reset_btn = gr.Button("Reset Index (Clear chroma_data)", variant="secondary")

        # Tag visualization section
        with gr.Accordion("Tag Visualization", open=False):
            tag_visualization = gr.Markdown(label="Generated Tags (Top 20)", value="Tags will appear here after processing...")

    with gr.Tab("Search & Compare"):
        gr.Markdown("## Compare All Retrieval Methods Side-by-Side")
        with gr.Row():
            with gr.Column():
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your query...",
                    lines=2
                )
                tags_input = gr.Textbox(
                    label="Tags (comma-separated, optional)",
                    placeholder="tag1, tag2, tag3",
                    lines=1
                )
                with gr.Row():
                    tag_operator = gr.Radio(
                        choices=["OR", "AND", "NOT"],
                        value="OR",
                        label="Tag Operator",
                        info="OR: any tag, AND: all tags, NOT: exclude tags"
                    )
                    k_slider = gr.Slider(
                        minimum=1, maximum=20,
                        value=int(os.getenv("DEFAULT_SEARCH_K", 5)),
                        step=1,
                        label="Number of results (k)"
                    )
                with gr.Row():
                    vector_weight = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.7, step=0.1,
                        label="Vector Weight (for Hybrid)",
                        info="Weight for vector similarity in hybrid search"
                    )
                    tag_weight = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.3, step=0.1,
                        label="Tag Weight (for Hybrid)",
                        info="Weight for tag matching in hybrid search"
                    )
                search_btn = gr.Button("Search All Methods", variant="primary")
                search_status = gr.Markdown("**Status**: Ready - Click 'Search All Methods' to start")

            with gr.Column():
                with gr.Row():
                    base_results = gr.Textbox(
                        label="Base RAG",
                        lines=6,
                        max_lines=10,
                        value="Results will appear here..."
                    )
                    tag_results = gr.Textbox(
                        label="Tag Filter RAG",
                        lines=6,
                        max_lines=10,
                        value="Results will appear here..."
                    )
                with gr.Row():
                    hybrid_results = gr.Textbox(
                        label="Hybrid RAG",
                        lines=6,
                        max_lines=10,
                        value="Results will appear here..."
                    )
                    rerank_results = gr.Textbox(
                        label="Hybrid Rerank RAG",
                        lines=6,
                        max_lines=10,
                        value="Results will appear here..."
                    )

        search_summary = gr.Markdown()

with gr.Tab("Chat Interface"):
|
| 899 |
+
gr.Markdown("## Natural Conversation with Tag-Enhanced RAG")
|
| 900 |
+
with gr.Row():
|
| 901 |
+
with gr.Column(scale=1):
|
| 902 |
+
pipeline_radio = gr.Radio(
|
| 903 |
+
choices=["base_rag", "tag_filter_rag", "hybrid_rag", "hybrid_rerank_rag"],
|
| 904 |
+
label="RAG Pipeline",
|
| 905 |
+
value="hybrid_rag"
|
| 906 |
+
)
|
| 907 |
+
use_tags_toggle = gr.Checkbox(
|
| 908 |
+
label="Enable Tag Filtering",
|
| 909 |
+
value=False,
|
| 910 |
+
info="Use tags for tag-based pipelines"
|
| 911 |
+
)
|
| 912 |
+
chat_tags_input = gr.Textbox(
|
| 913 |
+
label="Tags (comma-separated, for tag-based pipelines)",
|
| 914 |
+
placeholder="tag1, tag2",
|
| 915 |
+
lines=1,
|
| 916 |
+
visible=True
|
| 917 |
+
)
|
| 918 |
+
chat_k_slider = gr.Slider(
|
| 919 |
+
minimum=1, maximum=10, value=3, step=1,
|
| 920 |
+
label="Number of results"
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
with gr.Column(scale=2):
|
| 924 |
+
chatbot = gr.Chatbot(label="RAG Chat", type="messages", height=400)
|
| 925 |
+
chat_input = gr.Textbox(
|
| 926 |
+
label="Message",
|
| 927 |
+
placeholder="Ask a question...",
|
| 928 |
+
lines=2
|
| 929 |
+
)
|
| 930 |
+
chat_btn = gr.Button("Send", variant="primary")
|
| 931 |
+
|
| 932 |
+
# Source documents display with tags
|
| 933 |
+
with gr.Accordion("📎 Source Documents", open=False):
|
| 934 |
+
chat_sources = gr.Markdown(label="Retrieved Documents with Tags", value="Source documents will appear here after chat...")
|
| 935 |
+
|
| 936 |
+
with gr.Tab("Analytics & Evaluation"):
|
| 937 |
+
gr.Markdown("## Performance Visualization and Metrics")
|
| 938 |
+
with gr.Row():
|
| 939 |
+
with gr.Column():
|
| 940 |
+
eval_queries = gr.Textbox(
|
| 941 |
+
label="Evaluation Queries (JSON)",
|
| 942 |
+
lines=8,
|
| 943 |
+
placeholder='''[
|
| 944 |
+
{
|
| 945 |
+
"query": "What are emergency procedures?",
|
| 946 |
+
"ground_truth": ["chunk_id_1", "chunk_id_2"],
|
| 947 |
+
"k_values": [1, 3, 5],
|
| 948 |
+
"tags": ["emergency", "procedure"]
|
| 949 |
+
}
|
| 950 |
+
]''',
|
| 951 |
+
value='''[
|
| 952 |
+
{
|
| 953 |
+
"query": "What are the emergency procedures?",
|
| 954 |
+
"ground_truth": ["Emergency protocols for triage", "Patient assessment guidelines"],
|
| 955 |
+
"k_values": [1, 3, 5]
|
| 956 |
+
}
|
| 957 |
+
]'''
|
| 958 |
+
)
|
| 959 |
+
user_satisfaction_input = gr.Textbox(
|
| 960 |
+
label="User Satisfaction Scores (JSON, optional)",
|
| 961 |
+
lines=3,
|
| 962 |
+
placeholder='''{
|
| 963 |
+
"query_0": 4.5,
|
| 964 |
+
"query_1": 4.2,
|
| 965 |
+
"query_2": 4.8
|
| 966 |
+
}''',
|
| 967 |
+
value=""
|
| 968 |
+
)
|
| 969 |
+
eval_output_name = gr.Textbox(
|
| 970 |
+
label="Output Filename",
|
| 971 |
+
value="evaluation_results.csv"
|
| 972 |
+
)
|
| 973 |
+
with gr.Row():
|
| 974 |
+
eval_btn = gr.Button("Run Evaluation", variant="primary")
|
| 975 |
+
export_csv_btn = gr.Button("Download CSV", variant="secondary")
|
| 976 |
+
export_png_btn = gr.Button("Download Charts", variant="secondary")
|
| 977 |
+
|
| 978 |
+
with gr.Column():
|
| 979 |
+
eval_output = gr.Textbox(label="Evaluation Status", lines=3)
|
| 980 |
+
eval_summary_text = gr.Markdown(label="Summary Statistics")
|
| 981 |
+
eval_results_table = gr.DataFrame(label="Evaluation Results")
|
| 982 |
+
query_history = gr.DataFrame(label="Query History with Performance Scores", visible=False)
|
| 983 |
+
export_csv_file = gr.File(visible=False)
|
| 984 |
+
export_charts_files = gr.File(visible=False)
|
| 985 |
+
|
| 986 |
+
with gr.Tabs():
|
| 987 |
+
with gr.Tab("Bar Charts"):
|
| 988 |
+
eval_bar_chart = gr.Image(label="Metric Comparison (Bar Chart)")
|
| 989 |
+
with gr.Tab("Line Plots"):
|
| 990 |
+
eval_line_plot = gr.Image(label="Metric Trends (Line Plot)")
|
| 991 |
+
with gr.Tab("Scatter Plots"):
|
| 992 |
+
eval_scatter_plot = gr.Image(label="Correlation Analysis (Scatter Plot)")
|
| 993 |
+
with gr.Tab("Box Plots"):
|
| 994 |
+
eval_box_plot = gr.Image(label="Distribution Analysis (Box Plot)")
|
| 995 |
+
with gr.Tab("Stacked Bar"):
|
| 996 |
+
eval_stacked_plot = gr.Image(label="Method Breakdown (Stacked Bar)")
|
| 997 |
+
with gr.Tab("Pareto"):
|
| 998 |
+
eval_pareto_plot = gr.Image(label="Performance Ranking (Pareto Chart)")
|
| 999 |
+
with gr.Tab("Settings & Management"):
|
| 1000 |
+
gr.Markdown("## System Configuration and User Management")
|
| 1001 |
+
|
| 1002 |
+
with gr.Accordion("Tag Generation Parameters", open=False):
|
| 1003 |
+
max_tags_slider = gr.Slider(
|
| 1004 |
+
minimum=5, maximum=50, value=10, step=1,
|
| 1005 |
+
label="Max Tags Per Chunk",
|
| 1006 |
+
info="Maximum number of tags to generate per document chunk"
|
| 1007 |
+
)
|
| 1008 |
+
min_tag_length_slider = gr.Slider(
|
| 1009 |
+
minimum=1, maximum=5, value=2, step=1,
|
| 1010 |
+
label="Min Tag Length (words)",
|
| 1011 |
+
info="Minimum number of words in a tag phrase"
|
| 1012 |
+
)
|
| 1013 |
+
max_tag_length_slider = gr.Slider(
|
| 1014 |
+
minimum=1, maximum=5, value=3, step=1,
|
| 1015 |
+
label="Max Tag Length (words)",
|
| 1016 |
+
info="Maximum number of words in a tag phrase"
|
| 1017 |
+
)
|
| 1018 |
+
tag_method_dropdown = gr.Dropdown(
|
| 1019 |
+
choices=["auto", "yake", "keybert", "spacy", "janome", "openai"],
|
| 1020 |
+
value="auto",
|
| 1021 |
+
label="Tag Generation Method",
|
| 1022 |
+
info="Method for generating tags (auto selects best available)"
|
| 1023 |
+
)
|
| 1024 |
+
apply_tag_params_btn = gr.Button("Apply Tag Settings", variant="primary")
|
| 1025 |
+
tag_params_status = gr.Textbox(label="Status", lines=2, interactive=False)
|
| 1026 |
+
|
| 1027 |
+
with gr.Accordion("Hybrid Search Weights", open=False):
|
| 1028 |
+
default_vector_weight = gr.Slider(
|
| 1029 |
+
minimum=0.0, maximum=1.0, value=0.7, step=0.1,
|
| 1030 |
+
label="Default Vector Weight"
|
| 1031 |
+
)
|
| 1032 |
+
default_tag_weight = gr.Slider(
|
| 1033 |
+
minimum=0.0, maximum=1.0, value=0.3, step=0.1,
|
| 1034 |
+
label="Default Tag Weight"
|
| 1035 |
+
)
|
| 1036 |
+
apply_weights_btn = gr.Button("Apply Weight Settings", variant="primary")
|
| 1037 |
+
weights_status = gr.Textbox(label="Status", lines=2, interactive=False)
|
| 1038 |
+
|
| 1039 |
+
with gr.Accordion("Database Management", open=False):
|
| 1040 |
+
with gr.Row():
|
| 1041 |
+
clear_data_btn = gr.Button("Clear All Data", variant="stop")
|
| 1042 |
+
export_data_btn = gr.Button("Export Database", variant="secondary")
|
| 1043 |
+
import_data_btn = gr.File(label="Import Database", file_count="single", file_types=[".sqlite3", ".db"])
|
| 1044 |
+
db_status = gr.Textbox(label="Database Status", lines=2, interactive=False)
|
| 1045 |
+
|
| 1046 |
+
with gr.Accordion("Embedding Configuration", open=True):
|
| 1047 |
+
gr.Markdown("**Select the embedding provider and model.** Switching providers requires re-indexing your documents.")
|
| 1048 |
+
gr.Markdown("**Note:** For OpenAI embeddings, set `OPENAI_API_KEY` in your `.env` file or environment variables. API keys should not be set through the UI.")
|
| 1049 |
+
|
| 1050 |
+
with gr.Row():
|
| 1051 |
+
with gr.Column():
|
| 1052 |
+
emb_provider = gr.Radio(
|
| 1053 |
+
choices=["SentenceTransformers", "OpenAI"],
|
| 1054 |
+
value="SentenceTransformers",
|
| 1055 |
+
label="Embeddings Provider",
|
| 1056 |
+
info="Choose between local SentenceTransformers models or OpenAI embeddings (requires OPENAI_API_KEY in .env)"
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
with gr.Row():
|
| 1060 |
+
apply_embed_btn = gr.Button("Apply Embedding Settings", variant="primary")
|
| 1061 |
+
|
| 1062 |
+
with gr.Row():
|
| 1063 |
+
with gr.Column():
|
| 1064 |
+
st_model_in = gr.Textbox(
|
| 1065 |
+
label="SentenceTransformers Model",
|
| 1066 |
+
value=os.getenv("ST_EMBED_MODEL", "all-MiniLM-L6-v2"),
|
| 1067 |
+
interactive=False,
|
| 1068 |
+
info="Local embedding model (384 dimensions)"
|
| 1069 |
+
)
|
| 1070 |
+
with gr.Column():
|
| 1071 |
+
oai_model_in = gr.Textbox(
|
| 1072 |
+
label="OpenAI Embedding Model",
|
| 1073 |
+
value=os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small"),
|
| 1074 |
+
interactive=False,
|
| 1075 |
+
info="OpenAI embedding model (1536 dimensions for small, 3072 for large)"
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
embed_status = gr.Textbox(
|
| 1079 |
+
label="Status",
|
| 1080 |
+
lines=3,
|
| 1081 |
+
interactive=False,
|
| 1082 |
+
placeholder="Embedding configuration status will appear here..."
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
# Define handler before wiring it
|
| 1086 |
+
def _apply_embeddings(provider, st_model, oai_model):
|
| 1087 |
+
try:
|
| 1088 |
+
use_oai = (provider == "OpenAI")
|
| 1089 |
+
rag_manager.vector_store.configure_embeddings(use_oai, openai_model=oai_model, st_model_name=st_model)
|
| 1090 |
+
status_msg = f"✅ Embeddings successfully configured!\n\n"
|
| 1091 |
+
status_msg += f"Provider: {provider}\n"
|
| 1092 |
+
if use_oai:
|
| 1093 |
+
status_msg += f"Model: {oai_model} (OpenAI)\n"
|
| 1094 |
+
status_msg += f"Dimensions: {3072 if 'large' in oai_model.lower() else 1536}\n"
|
| 1095 |
+
else:
|
| 1096 |
+
status_msg += f"Model: {st_model} (SentenceTransformers)\n"
|
| 1097 |
+
status_msg += f"Dimensions: ~384\n"
|
| 1098 |
+
status_msg += f"\n⚠️ Note: If switching providers, reset and rebuild your index in the Upload tab."
|
| 1099 |
+
return status_msg
|
| 1100 |
+
except Exception as ex:
|
| 1101 |
+
return f"❌ Failed to set embeddings: {ex}\n\nPlease check your configuration and try again."
|
| 1102 |
+
|
| 1103 |
+
        # Handler functions for Settings tab
        def apply_tag_params(max_tags, min_len, max_len, method):
            """Apply tag generation parameters"""
            global rag_manager
            if not rag_manager:
                return "System not initialized. Please initialize first."
            # Store settings in environment or config (simplified for now)
            os.environ['MAX_TAGS_PER_CHUNK'] = str(max_tags)
            os.environ['MIN_TAG_LENGTH'] = str(min_len)
            os.environ['MAX_TAG_LENGTH'] = str(max_len)
            os.environ['TAG_GENERATION_METHOD'] = method
            return f"✅ Tag parameters updated:\n- Max tags: {max_tags}\n- Tag length: {min_len}-{max_len} words\n- Method: {method}"

        def apply_weight_settings(vec_weight, tag_weight):
            """Apply default hybrid search weights"""
            os.environ['DEFAULT_VECTOR_WEIGHT'] = str(vec_weight)
            os.environ['DEFAULT_TAG_WEIGHT'] = str(tag_weight)
            return f"✅ Default weights updated:\n- Vector: {vec_weight}\n- Tag: {tag_weight}"

        def clear_all_data():
            """Clear all database data"""
            try:
                reset_index()
                return "✅ All data cleared successfully. Please rebuild your index."
            except Exception as e:
                return f"❌ Error clearing data: {str(e)}"

        def export_database():
            """Export database"""
            try:
                db_path = persist_directory or "./chroma_data"
                if os.path.exists(db_path):
                    import shutil
                    export_path = f"export_{int(time.time())}.tar.gz"
                    shutil.make_archive(export_path.replace('.tar.gz', ''), 'gztar', db_path)
                    return f"✅ Database exported to: {export_path}"
                return "❌ No database found to export"
            except Exception as e:
                return f"❌ Export failed: {str(e)}"

        apply_tag_params_btn.click(
            fn=apply_tag_params,
            inputs=[max_tags_slider, min_tag_length_slider, max_tag_length_slider, tag_method_dropdown],
            outputs=[tag_params_status]
        )

        apply_weights_btn.click(
            fn=apply_weight_settings,
            inputs=[default_vector_weight, default_tag_weight],
            outputs=[weights_status]
        )

        clear_data_btn.click(
            fn=clear_all_data,
            inputs=None,
            outputs=[db_status]
        )

        export_data_btn.click(
            fn=export_database,
            inputs=None,
            outputs=[db_status]
        )

        apply_embed_btn.click(
            fn=_apply_embeddings,
            inputs=[emb_provider, st_model_in, oai_model_in],
            outputs=embed_status
        )

    # Initialize session on load with localStorage support
    def init_session_on_load(browser_sid):
        """Initialize session from localStorage or create new one"""
        # Read session ID from BrowserState (localStorage)
        # browser_sid could be None, str, or already be a session ID
        if browser_sid and isinstance(browser_sid, str):
            session_id = browser_sid.strip() if browser_sid.strip() else None
        else:
            session_id = None

        result = init_session(session_id)

        # Return session_state, the session_id (BrowserState will auto-save it to
        # localStorage), the document count, and the session indicator
        session_id_str = result.get("session_id", "")
        doc_count = get_document_count(result)
        doc_count_str = f"**Documents**: {doc_count}"
        session_indicator_str = f"**Session**: {session_id_str[:8]}..." if session_id_str else "**Session**: Not initialized"
        logger.info(f"Session initialized: {session_id_str[:8]}... (from localStorage: {session_id is not None})")
        return result, session_id_str, doc_count_str, session_indicator_str

    # Read the session ID from localStorage on page load and initialize the session
    demo.load(
        fn=init_session_on_load,
        inputs=[browser_session_id],
        outputs=[session_state, browser_session_id, document_count_indicator, session_indicator],
        queue=False
    )

    # Save the session ID to BrowserState (localStorage) whenever session_state is updated
    def update_browser_session(session_data: Dict[str, Any]) -> Tuple[Dict[str, Any], str, str, str]:
        """Update session state, save to BrowserState (localStorage), and update the document count and session indicator"""
        session_id = session_data.get("session_id", "") if session_data else ""
        doc_count = get_document_count(session_data)
        doc_count_str = f"**Documents**: {doc_count}"
        session_indicator_str = f"**Session**: {session_id[:8]}..." if session_id else "**Session**: Not initialized"
        return session_data, session_id, doc_count_str, session_indicator_str

    # Hook into session_state changes to save to BrowserState (localStorage) and update the document count
    session_state.change(
        fn=update_browser_session,
        inputs=[session_state],
        outputs=[session_state, browser_session_id, document_count_indicator, session_indicator],
        queue=False
    )

    # Event handlers
    build_btn.click(
        fn=build_with_session,
        inputs=[file_upload, language_dropdown, manual_tags_input, session_state],
        outputs=[session_state, build_output, stats_table, chunks_table, tag_visualization],
        api_name="build_rag"
    ).then(
        fn=lambda s: gr.update(value=f"**Session**: {s.get('session_id', 'Unknown')[:8]}..." if s and s.get('session_id') else "**Session**: Not initialized"),
        inputs=[session_state],
        outputs=[session_indicator],
        queue=False
    ).then(
        fn=lambda s: gr.update(value=f"**Documents**: {get_document_count(s)}"),
        inputs=[session_state],
        outputs=[document_count_indicator],
        queue=False
    )

    def reset_index_with_count(session_state: Dict[str, Any]) -> Tuple[str, str]:
        """Reset index and return updated document count"""
        reset_msg = reset_index()
        # After reset, document count should be 0
        return reset_msg, "**Documents**: 0"

    reset_btn.click(
        fn=reset_index_with_count,
        inputs=[session_state],
        outputs=[build_output, document_count_indicator]
    )

    # Search all methods - process sequentially with status updates
    def search_with_status(query: str, k: int, tags: str, tag_operator: str, vector_weight: float, tag_weight: float, session_state: Dict[str, Any]):
        """Wrapper to update status during search"""
        return search_all_methods(query, k, tags, tag_operator, vector_weight, tag_weight, session_state)

    search_btn.click(
        fn=lambda: gr.update(value="**Status**: 🔄 Processing Base RAG..."),
        outputs=[search_status],
        queue=False
    ).then(
        fn=search_all_methods,
        inputs=[search_query, k_slider, tags_input, tag_operator, vector_weight, tag_weight, session_state],
        outputs=[base_results, tag_results, hybrid_results, rerank_results, search_summary, session_state],
        api_name="search_all"
    ).then(
        fn=lambda: gr.update(value="**Status**: ✅ All methods completed!"),
        outputs=[search_status],
        queue=False
    )

    # Toggle tag input visibility
    def toggle_tag_input(use_tags):
        return gr.update(visible=use_tags)

    use_tags_toggle.change(
        fn=toggle_tag_input,
        inputs=[use_tags_toggle],
        outputs=[chat_tags_input]
    )

    chat_btn.click(
        fn=chat_with_rag,
        inputs=[chat_input, chatbot, pipeline_radio, chat_k_slider, chat_tags_input, use_tags_toggle, session_state],
        outputs=[chat_input, chatbot, session_state, chat_sources],
        api_name="chat"
    ).then(
        lambda: None,
        None,
        chat_input,
        queue=False
    )

    eval_viz_state = gr.State(value={})
    eval_output_filename_state = gr.State(value="")

    eval_btn.click(
        fn=run_evaluation_with_viz,
        inputs=[eval_queries, eval_output_name, user_satisfaction_input, session_state],
        outputs=[eval_output, eval_results_table, eval_summary_text, eval_bar_chart, eval_line_plot, eval_viz_state, eval_output_filename_state],
        api_name="evaluate"
    ).then(
        fn=update_remaining_vizs,
        inputs=[eval_viz_state],
        outputs=[eval_scatter_plot, eval_box_plot, eval_stacked_plot, eval_pareto_plot],
        queue=False
    )

    # Export button handlers
    def export_csv_wrapper(filename_state):
        """Export CSV from current evaluation results - returns file for download"""
        try:
            # Use the stored filename from evaluation if available, otherwise the provided filename
            filename = filename_state if filename_state else "evaluation_results.csv"

            # Ensure .csv extension
            if not filename.endswith('.csv'):
                filename = f"{filename}.csv"

            csv_path = os.path.join("reports", filename)

            if os.path.exists(csv_path):
                # Return absolute path for download
                return os.path.abspath(csv_path)
            else:
                # Try to find any CSV file in the reports directory
                reports_dir = "reports"
                if os.path.exists(reports_dir):
                    csv_files = [f for f in os.listdir(reports_dir) if f.endswith('.csv')]
                    if csv_files:
                        # Return the most recent one
                        csv_files.sort(key=lambda x: os.path.getmtime(os.path.join(reports_dir, x)), reverse=True)
                        return os.path.abspath(os.path.join(reports_dir, csv_files[0]))

            return None
        except Exception as e:
            logger.error(f"CSV export error: {e}")
            return None

    def export_png_wrapper():
        """Export PNG charts - creates a ZIP file for download"""
        try:
            viz_dir = "reports/visualizations"
            if os.path.exists(viz_dir):
                png_files = [f for f in os.listdir(viz_dir) if f.endswith('.png')]
                if png_files:
                    # Create a ZIP file containing all PNG charts
                    zip_path = os.path.join("reports", "charts.zip")
                    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                        for png_file in png_files:
                            file_path = os.path.join(viz_dir, png_file)
                            # Add file to ZIP with just the filename (no directory structure)
                            zipf.write(file_path, arcname=png_file)
                    # Return absolute path to the ZIP file
                    return os.path.abspath(zip_path)
            return None
        except Exception as e:
            logger.error(f"Charts export error: {e}")
            return None

    # JavaScript to directly trigger a browser download from the File component value.
    # When the File component receives a file path, Gradio creates a downloadable URL;
    # we extract that URL and trigger the download immediately.
    # (Raw string so the JS regex escapes pass through unchanged.)
    auto_download_js = r"""
    function(fileValue) {
        // Wait a bit for Gradio to process the file and create the download URL
        setTimeout(() => {
            if (fileValue) {
                // The File component value can be a string (path) or an object with file info
                let fileUrl = null;
                let filename = 'download';

                if (typeof fileValue === 'string') {
                    // If it's already a URL (Gradio file endpoint)
                    if (fileValue.startsWith('http') || fileValue.startsWith('/')) {
                        fileUrl = fileValue;
                        filename = fileValue.split('/').pop().split('?')[0] || 'download';
                    } else {
                        // If it's a file path, construct the Gradio file URL
                        const baseUrl = window.location.origin + window.location.pathname.replace(/\/$/, '');
                        fileUrl = baseUrl + '/file=' + encodeURIComponent(fileValue);
                        filename = fileValue.split('/').pop() || 'download';
                    }
                } else if (fileValue && fileValue.url) {
                    // If it's an object with a url property (Gradio FileData)
                    fileUrl = fileValue.url;
                    filename = fileValue.name || fileValue.url.split('/').pop().split('?')[0] || 'download';
                }

                if (fileUrl) {
                    // Create and trigger the download
                    const link = document.createElement('a');
                    link.href = fileUrl;
                    link.download = filename;
                    link.style.display = 'none';
                    document.body.appendChild(link);
                    link.click();
                    setTimeout(() => document.body.removeChild(link), 100);
                }
            }
        }, 300);
        return fileValue;
    }
    """

    export_csv_btn.click(
        fn=export_csv_wrapper,
        inputs=[eval_output_filename_state],
        outputs=[export_csv_file]
    ).then(
        fn=None,
        inputs=[export_csv_file],
        outputs=None,
        js=auto_download_js
    )

    export_png_btn.click(
        fn=export_png_wrapper,
        inputs=None,
        outputs=[export_charts_files]
    ).then(
        fn=None,
        inputs=[export_charts_files],
        outputs=None,
        js=auto_download_js
    )

    # Diagnostics trigger removed

# MCP Server Implementation
import asyncio
import sys
from typing import Any, List, Optional
try:
    from mcp.server import Server
    from mcp.server.models import InitializationOptions
    from mcp.types import Tool, TextContent
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False
    # Fallback for when MCP is not installed
    Server = None
    Tool = None
    TextContent = None

class RAGMCPServer:
    """MCP server for the RAG system"""

    def __init__(self):
        persist_dir = "/data/chroma" if os.path.exists("/data/chroma") else "./chroma_data"
        self.rag_manager = RAGManager(persist_directory=persist_dir)
        self.evaluator = RAGEvaluator(self.rag_manager)

    async def list_tools(self) -> List[Tool]:
        """List available MCP tools"""
        return [
            Tool(
                name="search_documents",
                description="Search documents using RAG system (Base-RAG, Tag Filter, Hybrid, or Hybrid Rerank)",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"},
                        "k": {"type": "integer", "description": "Number of results", "default": 5},
                        "pipeline": {"type": "string", "enum": ["base_rag", "tag_filter_rag", "hybrid_rag", "hybrid_rerank_rag"], "default": "base_rag"},
                        "tags": {"type": "array", "items": {"type": "string"}, "description": "Tags for tag-based search"},
                        "tag_operator": {"type": "string", "enum": ["OR", "AND", "NOT"], "description": "Tag operator (OR/AND/NOT)", "default": "OR"},
                    },
                    "required": ["query"]
                }
            ),
            Tool(
                name="evaluate_retrieval",
                description="Evaluate RAG performance with batch queries",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "queries": {
                            "type": "array",
                            "description": "List of query objects with query, ground_truth, k_values, and optional filters",
                            "items": {"type": "object"}
                        },
                        "output_file": {"type": "string", "description": "Output filename for results"}
                    },
                    "required": ["queries"]
                }
            )
        ]

    async def call_tool(self, name: str, arguments: dict) -> List[TextContent]:
        """Call an MCP tool by name"""
        if name == "search_documents":
            query = arguments.get("query")
            k = arguments.get("k", 5)
            pipeline = arguments.get("pipeline", "base_rag")
            tags = arguments.get("tags", [])
            tag_operator = arguments.get("tag_operator", "OR")

            if pipeline == "base_rag":
                result = self.rag_manager.base_rag.retrieve(query, k)
            elif pipeline == "tag_filter_rag":
                result = self.rag_manager.tag_filter_rag.retrieve(query, k, tags=tags, tag_operator=tag_operator)
            elif pipeline == "hybrid_rag":
                result = self.rag_manager.hybrid_rag.retrieve(query, k, tags=tags, vector_weight=0.7, tag_weight=0.3)
            elif pipeline == "hybrid_rerank_rag":
                result = self.rag_manager.hybrid_rerank_rag.retrieve(query, k, tags=tags, vector_weight=0.7, tag_weight=0.3)
            else:
                result = self.rag_manager.base_rag.retrieve(query, k)

            response = {
                "content": result.content,
                "sources": [
                    {
                        "content": source['content'][:200],
                        "metadata": source['metadata'],
                        "score": source['score']
                    } for source in result.sources
                ],
                "latency": result.latency,
                "strategy": pipeline
            }

            return [TextContent(type="text", text=json.dumps(response, indent=2))]

        elif name == "evaluate_retrieval":
            queries = arguments.get("queries", [])
            output_file = arguments.get("output_file")

            # batch_evaluate returns (df, summary_dict, raw_results); unpack all three
            df, summary_dict, raw_results = self.evaluator.batch_evaluate(queries, output_file)

            summary = df.groupby('pipeline').agg({
                'hit_at_k': 'mean',
                'mrr': 'mean',
                'semantic_similarity': 'mean',
                'latency': 'mean'
            }).reset_index()

            response = {
                "summary": summary.to_dict('records'),
                "total_queries": len(queries),
                "output_file": output_file
            }

            return [TextContent(type="text", text=json.dumps(response, indent=2))]

        else:
            raise ValueError(f"Unknown tool: {name}")

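# Usage sketch (hypothetical; assumes an index already exists in the persist
# directory): call_tool is a coroutine, so drive it with asyncio:
#   server = RAGMCPServer()
#   texts = asyncio.run(server.call_tool(
#       "search_documents", {"query": "emergency procedures", "k": 3}))
#   print(texts[0].text)
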
# Export for Gradio Client
if __name__ == "__main__":
    # If run as a CLI, prefer plain Gradio serving. Spaces will import `demo` directly.
    # Respect common hosting env vars.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", os.getenv("GRADIO_SERVER_PORT", 7860)))
    # Avoid SSR and API schema on Spaces to prevent response length errors
    demo.launch(server_name=host, server_port=port, share=False, ssl_verify=False, ssr_mode=False)
core/comparison.py
ADDED
@@ -0,0 +1,148 @@
"""
Comparison Framework for RAG Methods

This module provides utilities for comparing different RAG retrieval methods
side-by-side with aggregated metrics and easy-to-use comparison functions.
"""

import pandas as pd
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from .eval import RAGEvaluator
from .retrieval import RetrievalResult


class RAGComparisonFramework:
    """Framework for comparing RAG retrieval methods side-by-side"""

    # Method name mappings for display
    METHOD_NAMES = {
        'base_rag': 'Baseline',
        'tag_filter_rag': '+Tags(Filter)',
        'hybrid_rag': 'Hybrid(Weighted)',
        'hybrid_rerank_rag': 'Hybrid+Rerank'
    }

    def __init__(self, evaluator: RAGEvaluator):
        """
        Initialize comparison framework.

        Args:
            evaluator: RAGEvaluator instance for evaluation
        """
        self.evaluator = evaluator

    def compare_methods(self,
                        queries: List[Dict[str, Any]],
                        k_values: Optional[List[int]] = None,
                        methods: Optional[List[str]] = None) -> pd.DataFrame:
        """
        Compare all methods side-by-side on the given queries.

        Args:
            queries: List of query dictionaries
            k_values: List of k values to evaluate (default: [1, 3, 5])
            methods: List of methods to compare (default: all 4 methods)

        Returns:
            DataFrame with side-by-side comparison
        """
        if k_values is None:
            k_values = [1, 3, 5]
        if methods is None:
            methods = ['base_rag', 'tag_filter_rag', 'hybrid_rag', 'hybrid_rerank_rag']

        # Run evaluation
        df, summary, raw_results = self.evaluator.batch_evaluate(
            queries=queries,
            pipelines=methods
        )

        return df

    def get_comparison_table(self,
                             df: pd.DataFrame,
                             k_value: int,
                             metrics: Optional[List[str]] = None) -> pd.DataFrame:
        """
        Get a comparison table for a specific k value.

        Args:
            df: Evaluation results DataFrame
            k_value: k value to filter by
            metrics: List of metrics to include (default: all)

        Returns:
            Comparison table DataFrame
        """
        if metrics is None:
            metrics = ['precision_at_k', 'ndcg_at_k', 'mrr', 'hit_at_k', 'latency']

        # Filter by k
        k_df = df[df['k'] == k_value]

        # Aggregate by pipeline
        comparison = k_df.groupby('pipeline')[metrics].mean().reset_index()

        # Rename pipelines for display
        comparison['pipeline'] = comparison['pipeline'].map(
            lambda x: self.METHOD_NAMES.get(x, x)
        )

        return comparison

    def get_aggregated_comparison(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Get an aggregated comparison across all k values.

        Args:
            df: Evaluation results DataFrame

        Returns:
            Aggregated comparison DataFrame
        """
        metrics = ['precision_at_k', 'ndcg_at_k', 'mrr', 'hit_at_k', 'latency']

        # Aggregate by pipeline
        aggregated = df.groupby('pipeline')[metrics].agg(['mean', 'std']).reset_index()

        # Flatten column names
        aggregated.columns = ['_'.join(col).strip('_') if col[1] else col[0]
                              for col in aggregated.columns.values]

        # Rename pipelines for display
        aggregated['pipeline'] = aggregated['pipeline'].map(
            lambda x: self.METHOD_NAMES.get(x, x)
        )

        return aggregated

    def get_method_rankings(self, df: pd.DataFrame, k_value: int,
                            metric: str = 'precision_at_k') -> pd.DataFrame:
        """
        Get method rankings by metric.

        Args:
            df: Evaluation results DataFrame
            k_value: k value to filter by
            metric: Metric to rank by

        Returns:
            Rankings DataFrame
        """
        k_df = df[df['k'] == k_value]

        # Average metric by pipeline
        rankings = k_df.groupby('pipeline')[metric].mean().reset_index()
        rankings = rankings.sort_values(metric, ascending=False)

        # Add ranking
        rankings['rank'] = range(1, len(rankings) + 1)

        # Rename pipelines for display
        rankings['pipeline'] = rankings['pipeline'].map(
            lambda x: self.METHOD_NAMES.get(x, x)
        )

        return rankings
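# Usage sketch (hypothetical query set): rank all four pipelines by nDCG@5.
#   framework = RAGComparisonFramework(evaluator)
#   df = framework.compare_methods([{"query": "What are the emergency procedures?",
#                                    "ground_truth": ["Emergency protocols for triage"]}])
#   print(framework.get_method_rankings(df, k_value=5, metric='ndcg_at_k'))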
core/eval.py
ADDED
@@ -0,0 +1,420 @@
import time
import json
import pandas as pd
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
import math
from .retrieval import RAGManager, RetrievalResult

class RAGEvaluator:
    """Evaluation framework for RAG systems"""

    def __init__(self, rag_manager: RAGManager):
        self.rag_manager = rag_manager

    def evaluate_single_query(self, query: str, ground_truth: List[str],
                              k_values: List[int] = [1, 3, 5, 10],
                              level1: Optional[str] = None,
                              level2: Optional[str] = None,
                              level3: Optional[str] = None,
                              doc_type: Optional[str] = None) -> Dict[str, Any]:
        """Evaluate retrieval for a single query"""

        base_results = {}
        hier_results = {}

        for k in k_values:
            # Get results from both pipelines
            base_result, hier_result = self.rag_manager.compare_retrieval(
                query, k, level1, level2, level3, doc_type
            )

            base_results[k] = base_result
            hier_results[k] = hier_result

        # Calculate metrics
        metrics = {
            "query": query,
            "ground_truth": ground_truth,
            "base_rag": self._calculate_metrics(base_results, ground_truth),
            "hier_rag": self._calculate_metrics(hier_results, ground_truth),
            "filters": {
                "level1": level1,
                "level2": level2,
                "level3": level3,
                "doc_type": doc_type
            }
        }

        return metrics

    def _calculate_metrics(self, results_dict: Dict[int, RetrievalResult],
                           ground_truth: List[str]) -> Dict[str, Any]:
        """Calculate evaluation metrics including Precision@k and nDCG@k"""
        metrics = {}

        for k, result in results_dict.items():
            retrieved_docs = [source['content'] for source in result.sources]

            # Hit@k
            hit_at_k = self._calculate_hit_at_k(retrieved_docs, ground_truth, k)

            # Precision@k
            precision_at_k = self._calculate_precision_at_k(retrieved_docs, ground_truth, k)

            # nDCG@k
            ndcg_at_k = self._calculate_ndcg_at_k(retrieved_docs, ground_truth, k)

            # MRR
            mrr = self._calculate_mrr(retrieved_docs, ground_truth)

            # Semantic similarity
            semantic_sim = self._calculate_semantic_similarity(retrieved_docs, ground_truth)

            metrics[k] = {
                "hit_at_k": hit_at_k,
                "precision_at_k": precision_at_k,
                "ndcg_at_k": ndcg_at_k,
                "mrr": mrr,
                "semantic_similarity": semantic_sim,
                "latency": result.latency,
                "retrieved_count": len(retrieved_docs)
            }

        return metrics

    def _calculate_precision_at_k(self, retrieved: List[str], ground_truth: List[str], k: int) -> float:
        """Calculate Precision@k metric"""
        if not ground_truth or not retrieved:
            return 0.0

        relevant = 0
        for doc in retrieved[:k]:
            for gt_doc in ground_truth:
                if self._documents_match(doc, gt_doc):
                    relevant += 1
                    break

        return relevant / min(k, len(retrieved))

    def _calculate_ndcg_at_k(self, retrieved: List[str], ground_truth: List[str], k: int) -> float:
        """Calculate Normalized Discounted Cumulative Gain at k"""
        if not ground_truth or not retrieved:
            return 0.0

        # Calculate DCG@k
        dcg = 0.0
        for i, doc in enumerate(retrieved[:k], 1):
            for gt_doc in ground_truth:
                if self._documents_match(doc, gt_doc):
                    # Binary relevance (can be enhanced with graded relevance)
                    relevance = 1.0
                    dcg += relevance / math.log2(i + 1)
                    break

        # Calculate ideal DCG (IDCG)
        idcg = 0.0
        num_relevant = min(k, len(ground_truth))
        for i in range(1, num_relevant + 1):
            idcg += 1.0 / math.log2(i + 1)

        # nDCG = DCG / IDCG
        if idcg == 0:
            return 0.0
        return dcg / idcg

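    # Worked example (illustrative note, not part of the committed file): with
    # two ground-truth passages and matches retrieved at ranks 1 and 3, k=3:
    #   DCG@3  = 1/log2(2) + 1/log2(4)  = 1.0 + 0.5   = 1.5
    #   IDCG@3 = 1/log2(2) + 1/log2(3) ~= 1.0 + 0.631 = 1.631
    #   nDCG@3 ~= 1.5 / 1.631 ~= 0.92
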
    def _calculate_hit_at_k(self, retrieved: List[str], ground_truth: List[str], k: int) -> float:
        """Calculate Hit@k metric"""
        if not ground_truth:
            return 0.0

        # Simple exact match (can be enhanced with semantic matching)
        for doc in retrieved[:k]:
            for gt_doc in ground_truth:
                if self._documents_match(doc, gt_doc):
                    return 1.0
        return 0.0

    def _calculate_mrr(self, retrieved: List[str], ground_truth: List[str]) -> float:
        """Calculate Mean Reciprocal Rank"""
        if not ground_truth:
            return 0.0

        for rank, doc in enumerate(retrieved, 1):
            for gt_doc in ground_truth:
                if self._documents_match(doc, gt_doc):
                    return 1.0 / rank
        return 0.0

    def _calculate_semantic_similarity(self, retrieved: List[str], ground_truth: List[str]) -> float:
        """Calculate average semantic similarity"""
        if not retrieved or not ground_truth:
            return 0.0

        # Use the same embedding model as the vector store
        embeddings_retrieved = [self.rag_manager.vector_store.embed_text(doc) for doc in retrieved]
        embeddings_gt = [self.rag_manager.vector_store.embed_text(doc) for doc in ground_truth]

        # Calculate cosine similarity matrix
        similarity_matrix = cosine_similarity(embeddings_retrieved, embeddings_gt)

        # Return max similarity for each retrieved document, then average
        max_similarities = np.max(similarity_matrix, axis=1)
        return float(np.mean(max_similarities))

    def _documents_match(self, doc1: str, doc2: str, threshold: float = 0.7) -> bool:
        """Check if two documents match (semantically or exactly)

        Uses semantic similarity with a threshold. Also checks for exact substring matches
        to handle cases where ground truth is a substring of the actual chunk.
        """
        # Normalize strings for comparison
        doc1_clean = doc1.strip().lower()
        doc2_clean = doc2.strip().lower()

        # Exact match or substring match (ground truth might be a substring of chunk)
        if doc1_clean == doc2_clean or doc1_clean in doc2_clean or doc2_clean in doc1_clean:
            return True

        # Semantic similarity check
        try:
            embedding1 = self.rag_manager.vector_store.embed_text(doc1)
            embedding2 = self.rag_manager.vector_store.embed_text(doc2)
            similarity = cosine_similarity([embedding1], [embedding2])[0][0]
            return similarity > threshold
        except Exception:
            # Fallback to exact match if embedding fails
            return doc1_clean == doc2_clean

    def batch_evaluate(self, queries: List[Dict[str, Any]],
                       output_file: Optional[str] = None,
                       pipelines: Optional[List[str]] = None,
                       user_satisfaction: Optional[Dict[str, int]] = None) -> Tuple[pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]]:
        """
        Batch evaluation on multiple queries across multiple pipelines.

        Args:
            queries: List of query dictionaries
            output_file: Optional output filename
            pipelines: List of pipeline names to evaluate (default: all 4 methods)
            user_satisfaction: Optional dict mapping query_id to satisfaction score (1-5)

        Returns:
            Tuple of (DataFrame, summary dict, raw results list)
        """
        if pipelines is None:
            pipelines = ['base_rag', 'tag_filter_rag', 'hybrid_rag', 'hybrid_rerank_rag']

        results = []
        all_latencies = defaultdict(list)

        for i, query_data in enumerate(queries):
            query_id = query_data.get('query_id', f'query_{i+1}')
            query = query_data['query']

            print(f"Evaluating query {i+1}/{len(queries)}: {query[:50]}...")

            # Get user satisfaction score - try both query_id and index-based keys
            user_sat_score = None
            if user_satisfaction:
                # Try query_id first
                user_sat_score = user_satisfaction.get(query_id)
                # If not found, try index-based key (query_0, query_1, etc.)
                if user_sat_score is None:
                    user_sat_score = user_satisfaction.get(f'query_{i}')

            query_result = {
                'query_id': query_id,
                'query': query,
                'ground_truth': query_data.get('ground_truth', []),
                'user_satisfaction': user_sat_score,
                'pipelines': {}
            }

            k_values = query_data.get('k_values', [1, 3, 5])

            for pipeline in pipelines:
                pipeline_results = {}

                for k in k_values:
                    try:
                        # Retrieve using the specified pipeline
                        retrieval_result = self._retrieve_from_pipeline(
                            pipeline, query, k, query_data
                        )

                        # Calculate metrics
                        retrieved_docs = [source['content'] for source in retrieval_result.sources]

                        metrics = {
                            'hit_at_k': self._calculate_hit_at_k(retrieved_docs, query_data.get('ground_truth', []), k),
                            'precision_at_k': self._calculate_precision_at_k(retrieved_docs, query_data.get('ground_truth', []), k),
                            'ndcg_at_k': self._calculate_ndcg_at_k(retrieved_docs, query_data.get('ground_truth', []), k),
                            'mrr': self._calculate_mrr(retrieved_docs, query_data.get('ground_truth', [])),
                            'semantic_similarity': self._calculate_semantic_similarity(retrieved_docs, query_data.get('ground_truth', [])),
                            'latency': retrieval_result.latency,
                            'retrieved_count': len(retrieved_docs)
                        }

                        pipeline_results[k] = metrics
                        all_latencies[pipeline].append(retrieval_result.latency)

                    except Exception as e:
                        print(f"Error evaluating {pipeline} for query {query_id}: {e}")
                        pipeline_results[k] = {
                            'hit_at_k': 0.0,
                            'precision_at_k': 0.0,
                            'ndcg_at_k': 0.0,
                            'mrr': 0.0,
                            'semantic_similarity': 0.0,
                            'latency': 0.0,
                            'retrieved_count': 0
                        }

                query_result['pipelines'][pipeline] = pipeline_results

            results.append(query_result)

        # Convert to DataFrame
        df = self._results_to_dataframe(results)

        # Calculate summary statistics
        summary = self._calculate_summary_statistics(df, all_latencies)

        # Save results if output file specified
        if output_file:
            import os
            reports_dir = os.path.join(os.getcwd(), "reports")
            os.makedirs(reports_dir, exist_ok=True)

            csv_path = os.path.join(reports_dir, output_file)
            json_path = os.path.join(reports_dir, output_file.replace('.csv', '.json'))

            df.to_csv(csv_path, index=False)

            # Save with summary - convert numpy types to Python native types
            save_data = {
                'results': self._convert_to_native_types(results),
                'summary': self._convert_to_native_types(summary)
            }
            with open(json_path, 'w') as f:
                json.dump(save_data, f, indent=2)

        return df, summary, results

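    # Illustrative shape of one `queries` entry (keys as consumed above and in
    # _retrieve_from_pipeline below; the values are invented):
    #   {'query_id': 'q1',
    #    'query': 'How do I reset my password?',
    #    'ground_truth': ['To reset your password, ...'],
    #    'k_values': [1, 3, 5],
    #    'tags': ['account', 'security'],
    #    'tag_operator': 'OR',
    #    'vector_weight': 0.7,
    #    'tag_weight': 0.3}
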
    def _retrieve_from_pipeline(self, pipeline: str, query: str, k: int,
                                query_data: Dict[str, Any]) -> RetrievalResult:
        """Retrieve from the specified pipeline"""
        if pipeline == 'base_rag':
            return self.rag_manager.base_rag.retrieve(query, k)
        elif pipeline == 'tag_filter_rag':
            tags = query_data.get('tags')
            return self.rag_manager.tag_filter_rag.retrieve(
                query, k, tags=tags, tag_operator=query_data.get('tag_operator', 'OR')
            )
        elif pipeline == 'hybrid_rag':
            tags = query_data.get('tags')
            return self.rag_manager.hybrid_rag.retrieve(
                query, k,
                tags=tags,
                tag_operator=query_data.get('tag_operator', 'OR'),
                vector_weight=query_data.get('vector_weight', 0.7),
                tag_weight=query_data.get('tag_weight', 0.3)
            )
        elif pipeline == 'hybrid_rerank_rag':
            tags = query_data.get('tags')
            return self.rag_manager.hybrid_rerank_rag.retrieve(
                query, k,
                tags=tags,
                tag_operator=query_data.get('tag_operator', 'OR'),
                vector_weight=query_data.get('vector_weight', 0.7),
                tag_weight=query_data.get('tag_weight', 0.3)
            )
        else:
            raise ValueError(f"Unknown pipeline: {pipeline}")

    def _calculate_summary_statistics(self, df: pd.DataFrame,
                                      all_latencies: Dict[str, List[float]]) -> Dict[str, Any]:
        """Calculate aggregated summary statistics"""
        summary = {}

        # Aggregate by pipeline and k
        for pipeline in df['pipeline'].unique():
            summary[pipeline] = {}
            pipeline_df = df[df['pipeline'] == pipeline]

            for k in df['k'].unique():
                # Convert numpy int64 to Python int for dictionary key
                k_int = int(k) if isinstance(k, (np.integer, np.int64)) else k
                k_df = pipeline_df[pipeline_df['k'] == k]

                summary[pipeline][k_int] = {
                    'mean_precision_at_k': float(k_df['precision_at_k'].mean()),
                    'mean_ndcg_at_k': float(k_df['ndcg_at_k'].mean()),
                    'mean_hit_at_k': float(k_df['hit_at_k'].mean()),
                    'mean_mrr': float(k_df['mrr'].mean()),
                    'mean_semantic_similarity': float(k_df['semantic_similarity'].mean()),
                    'mean_latency': float(k_df['latency'].mean()),
                    'p50_latency': float(k_df['latency'].quantile(0.5)),
                    'p90_latency': float(k_df['latency'].quantile(0.9))
                }

        # Overall latency percentiles per pipeline
        for pipeline, latencies in all_latencies.items():
            if latencies:
                summary[pipeline]['latency_percentiles'] = {
                    'p50': float(np.percentile(latencies, 50)),
                    'p90': float(np.percentile(latencies, 90))
                }

        return summary

    def _convert_to_native_types(self, obj):
        """Recursively convert numpy types to Python native types for JSON serialization"""
        if isinstance(obj, dict):
            # Convert numpy int64 keys to Python int (json.dump can handle int keys)
            return {int(k) if isinstance(k, (np.integer, np.int64, np.int32)) else k: self._convert_to_native_types(v)
                    for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._convert_to_native_types(item) for item in obj]
        elif isinstance(obj, (np.integer, np.int64, np.int32)):
            return int(obj)
        elif isinstance(obj, (np.floating, np.float64, np.float32)):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.bool_, bool)):
            return bool(obj)
        elif obj is None:
            return None
        else:
            return obj

    def _results_to_dataframe(self, results: List[Dict[str, Any]]) -> pd.DataFrame:
        """Convert evaluation results to DataFrame"""
        rows = []

        for result in results:
            query_id = result.get('query_id', 'unknown')
            query = result['query']

            for pipeline, pipeline_results in result.get('pipelines', {}).items():
                for k, metrics in pipeline_results.items():
                    rows.append({
                        'query_id': query_id,
                        'query': query,
                        'k': k,
                        'pipeline': pipeline,
                        'hit_at_k': metrics.get('hit_at_k', 0.0),
                        'precision_at_k': metrics.get('precision_at_k', 0.0),
                        'ndcg_at_k': metrics.get('ndcg_at_k', 0.0),
                        'mrr': metrics.get('mrr', 0.0),
                        'semantic_similarity': metrics.get('semantic_similarity', 0.0),
                        'latency': metrics.get('latency', 0.0),
                        'retrieved_count': metrics.get('retrieved_count', 0),
                        'user_satisfaction': result.get('user_satisfaction')
                    })

        return pd.DataFrame(rows)
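
Usage note: a minimal sketch of driving the evaluator end to end, assuming a RAGManager instance has already been constructed (its setup lives in core/retrieval.py, elsewhere in this commit):

    evaluator = RAGEvaluator(rag_manager)
    queries = [{'query': 'How do I reset my password?',
                'ground_truth': ['To reset your password, ...'],
                'k_values': [1, 3, 5]}]
    df, summary, raw = evaluator.batch_evaluate(queries, output_file='eval_run.csv')
    # CSV and JSON copies land under ./reports/; `summary` holds per-pipeline
    # means plus p50/p90 latency percentiles keyed by pipeline name and k.
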
core/index.py
ADDED
@@ -0,0 +1,389 @@
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Optional, Union
import os
import numpy as np
from .utils import Chunk
import os as _os
_OPENAI_EMBED = False
try:
    from openai import OpenAI as _OpenAI
    _OPENAI_EMBED = True if _os.getenv("OPENAI_API_KEY") else False
except Exception:
    _OpenAI = None  # keep the name defined so call-time availability checks are safe
    _OPENAI_EMBED = False
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None

class VectorStore:
    """Vector database management"""

    def __init__(self, persist_directory: str = "/data/chroma"):
        self.persist_directory = persist_directory

        # Ensure directory exists with proper permissions before creating client
        os.makedirs(persist_directory, exist_ok=True, mode=0o755)

        self.client = chromadb.PersistentClient(path=persist_directory)
        # Default to SentenceTransformers; runtime switching handled via configure_embeddings()
        self.use_openai = False
        if SentenceTransformer is None:
            raise RuntimeError("SentenceTransformers not available. Install sentence-transformers or switch to OpenAI via UI.")
        self.st_model_name = os.getenv("ST_EMBED_MODEL", "all-MiniLM-L6-v2")
        # Use local_files_only=True to work offline (model should already be cached)
        # This prevents network requests to HuggingFace Hub
        try:
            self.embedding_model = SentenceTransformer(self.st_model_name, local_files_only=True)
        except Exception:
            # Fallback: if local files not found, try with network (for first-time setup)
            self.embedding_model = SentenceTransformer(self.st_model_name)
        # Get model output dimension
        try:
            self.embed_dim = int(getattr(self.embedding_model, "get_sentence_embedding_dimension")())
        except Exception:
            # Fallback: compute once
            self.embed_dim = len(self.embedding_model.encode("test"))

    def _reopen_client(self, new_path: str):
        os.makedirs(new_path, exist_ok=True, mode=0o755)
        self.persist_directory = new_path
        self.client = chromadb.PersistentClient(path=new_path)

    def _collection_suffix(self) -> str:
        provider = "oai" if self.use_openai else "st"
        return f"{provider}_{self.embed_dim}"

    def _resolve_collection_name(self, base_name: str) -> str:
        """Ensure separate collections per embedding dimension/provider to avoid mismatch."""
        return f"{base_name}__{self._collection_suffix()}"

    def configure_embeddings(self, use_openai: bool, openai_model: Optional[str] = None, st_model_name: Optional[str] = None):
        """Reconfigure embedding backend at runtime.
        Switching providers/dimensions implies a new collection suffix; existing data remains under old suffix.
        """
        self.use_openai = bool(use_openai)
        if self.use_openai:
            # Check at call-time to avoid stale module-level flags
            if _OpenAI is None or not os.getenv("OPENAI_API_KEY"):
                raise RuntimeError("OpenAI not available or API key missing.")
            self.openai_client = _OpenAI()
            self.openai_model = openai_model or os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
            if self.openai_model == "text-embedding-3-large":
                self.embed_dim = 3072
            else:
                self.embed_dim = 1536
        else:
            if SentenceTransformer is None:
                raise RuntimeError("SentenceTransformer not available.")
            name = st_model_name or os.getenv("ST_EMBED_MODEL", "all-MiniLM-L6-v2")
            # Only reload if changed
            if not hasattr(self, 'st_model_name') or self.st_model_name != name:
                self.st_model_name = name
                # Use local_files_only=True to work offline
                try:
                    self.embedding_model = SentenceTransformer(self.st_model_name, local_files_only=True)
                except Exception:
                    # Fallback: if local files not found, try with network
                    self.embedding_model = SentenceTransformer(self.st_model_name)
            try:
                self.embed_dim = int(getattr(self.embedding_model, "get_sentence_embedding_dimension")())
            except Exception:
                self.embed_dim = len(self.embedding_model.encode("test"))

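    # Naming example (illustrative): a base collection "session_docs" resolves to
    # "session_docs__st_384" under all-MiniLM-L6-v2 (a 384-dim model) and to
    # "session_docs__oai_1536" after switching to text-embedding-3-small, so
    # vectors of different dimensions never share an index.
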
    def create_collection(self, name: str) -> chromadb.Collection:
        """Create or get collection, namespaced by embedding provider/dimension."""
        full_name = self._resolve_collection_name(name)
        return self.client.get_or_create_collection(
            name=full_name,
            metadata={"hnsw:space": "cosine", "embed_dim": str(self.embed_dim)}
        )

    def _embed_one(self, text: str) -> List[float]:
        """Generate embedding for a single text"""
        if self.use_openai:
            resp = self.openai_client.embeddings.create(model=self.openai_model, input=text)
            return resp.data[0].embedding
        return self.embedding_model.encode(text).tolist()

    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts"""
        if self.use_openai:
            resp = self.openai_client.embeddings.create(model=self.openai_model, input=texts)
            return [d.embedding for d in resp.data]
        return self.embedding_model.encode(texts).tolist()

    # Backward-compat API used by evaluator
    def embed_text(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Compatibility wrapper:
        - If text is str -> returns single embedding list[float]
        - If text is list[str] -> returns list of embeddings list[list[float]]
        """
        if isinstance(text, list):
            return self._embed_batch(text)
        return self._embed_one(text)

    def add_documents(self, collection_name: str, chunks: List[Chunk]):
        """Add documents to vector store"""
        collection = self.create_collection(collection_name)

        # Generate embeddings
        texts = [chunk.content for chunk in chunks]
        embeddings = self._embed_batch(texts)

        # Prepare metadata - ChromaDB doesn't support lists, so convert tags to string
        metadatas = []
        for chunk in chunks:
            metadata = chunk.metadata.copy()
            # Convert tags list to comma-separated string for ChromaDB
            if 'tags' in metadata and isinstance(metadata['tags'], list):
                metadata['tags'] = ', '.join(metadata['tags'])
            # Add doc_id to metadata for document counting
            if hasattr(chunk, 'doc_id') and chunk.doc_id:
                metadata['doc_id'] = chunk.doc_id
            metadatas.append(metadata)

        ids = [chunk.chunk_id for chunk in chunks]

        # Add to collection with writable fallback
        try:
            collection.add(
                embeddings=embeddings,
                documents=texts,
                metadatas=metadatas,
                ids=ids
            )
        except Exception as e:
            msg = str(e).lower()
            if "readonly" in msg or "read-only" in msg:
                # Fallback to a user-writable directory and retry once
                fallback_dir = os.getenv("CHROMA_PERSIST_DIR", os.path.join(os.getcwd(), "chroma_data"))
                if os.path.abspath(fallback_dir) == os.path.abspath(self.persist_directory):
                    # Choose a different fallback under the user's home cache
                    home_cache = os.path.join(os.path.expanduser("~"), ".cache", "rag-evaluation-system", "chroma")
                    fallback_dir = home_cache
                self._reopen_client(fallback_dir)
                collection = self.create_collection(collection_name)
                collection.add(
                    embeddings=embeddings,
                    documents=texts,
                    metadatas=metadatas,
                    ids=ids
                )
            else:
                raise

    def search(self, collection_name: str, query: str,
               filters: Optional[Dict[str, Any]] = None,
               tag_filters: Optional[Dict[str, Any]] = None,
               k: int = 5) -> List[Dict[str, Any]]:
        """
        Search in vector store with optional filters and tag filters.

        Args:
            collection_name: Collection name
            query: Search query
            filters: Standard metadata filters (doc_type, etc.)
            tag_filters: Tag filters dict with 'tags' (list) and 'operator' ('OR', 'AND', 'NOT')
            k: Number of results to return

        Returns:
            List of formatted search results
        """
        collection = self.create_collection(collection_name)

        # Generate query embedding
        query_embedding = self._embed_one(query)

        # Build where clause combining standard filters and tag filters
        where_clause = self._build_where_clause(filters, tag_filters)

        # For tag filtering (OR/AND/NOT), we need to fetch more results and post-filter
        # ChromaDB stores tags as comma-separated strings, so we filter in-memory
        fetch_k = k
        if tag_filters and tag_filters.get('tags'):
            fetch_k = k * 10  # Fetch more for post-filtering

        # Perform search
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=fetch_k,
            where=where_clause,
            include=["documents", "metadatas", "distances"]
        )

        # Format results
        formatted_results = []
        for i in range(len(results['documents'][0])):
            metadata = results['metadatas'][0][i].copy()
            # Convert tags from string (ChromaDB format) back to list
            if 'tags' in metadata and isinstance(metadata['tags'], str):
                metadata['tags'] = [tag.strip() for tag in metadata['tags'].split(',') if tag.strip()]

            formatted_results.append({
                'content': results['documents'][0][i],
                'metadata': metadata,
                'distance': results['distances'][0][i],
                'score': 1 - results['distances'][0][i],  # Convert to similarity score
                'id': results.get('ids', [None])[0][i] if results.get('ids') else None
            })

        # Post-filter for all tag operators (OR/AND/NOT)
        # Since tags are stored as comma-separated strings, we filter in-memory
        if tag_filters and tag_filters.get('tags'):
            formatted_results = self._post_filter_tags(
                formatted_results,
                tag_filters['tags'],
                tag_filters.get('operator', 'OR').upper()
            )
            formatted_results = formatted_results[:k]

        return formatted_results

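    # Call shape (illustrative): restrict to FAQ chunks tagged either "security"
    # or "login", returning at most 5 hits after in-memory tag filtering:
    #   store.search("session_docs", "how do I enable 2FA?",
    #                filters={"doc_type": "FAQ"},
    #                tag_filters={"tags": ["security", "login"], "operator": "OR"},
    #                k=5)
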
    def _build_where_clause(self, filters: Optional[Dict[str, Any]],
                            tag_filters: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Build ChromaDB where clause from filters and tag filters."""
        conditions = []

        # Add standard filters
        if filters:
            for key, value in filters.items():
                conditions.append({key: {"$eq": value}})

        # Add tag filters
        # Note: ChromaDB doesn't support list fields or $contains, so tags are stored as comma-separated strings
        # All tag filtering (OR/AND/NOT) is done in-memory after retrieval
        if tag_filters:
            tags = tag_filters.get('tags', [])
            operator = tag_filters.get('operator', 'OR').upper()
            # For tag filtering, we need to fetch more results and filter in-memory
            # Don't add tag filters to where clause

        if not conditions:
            return None
        elif len(conditions) == 1:
            return conditions[0]
        else:
            return {"$and": conditions}

    def _post_filter_tags(self, results: List[Dict[str, Any]],
                          tags: List[str], operator: str) -> List[Dict[str, Any]]:
        """
        Post-filter results based on tag operator (OR/AND/NOT).
        ChromaDB stores tags as comma-separated strings, so we filter in-memory.
        Uses case-insensitive matching and substring matching for flexibility.
        """
        filtered = []

        # Normalize search tags (lowercase, strip)
        normalized_search_tags = [tag.strip().lower() for tag in tags if tag.strip()]

        for result in results:
            metadata = result.get('metadata', {})
            result_tags = metadata.get('tags', [])

            # Convert tags from string (ChromaDB format) to list if needed
            if isinstance(result_tags, str):
                result_tags = [tag.strip() for tag in result_tags.split(',') if tag.strip()]
            elif not isinstance(result_tags, list):
                result_tags = []

            # Normalize result tags (lowercase, strip)
            normalized_result_tags = [tag.strip().lower() for tag in result_tags if tag.strip()]

            if operator == 'OR':
                # Any tag must be present (case-insensitive, substring match)
                # Check if any search tag matches any result tag (exact or substring)
                matches = False
                for search_tag in normalized_search_tags:
                    for result_tag in normalized_result_tags:
                        if search_tag == result_tag or search_tag in result_tag or result_tag in search_tag:
                            matches = True
                            break
                    if matches:
                        break
                if matches:
                    filtered.append(result)
            elif operator == 'AND':
                # All tags must be present (case-insensitive, substring match)
                all_match = True
                for search_tag in normalized_search_tags:
                    tag_matches = False
                    for result_tag in normalized_result_tags:
                        if search_tag == result_tag or search_tag in result_tag or result_tag in search_tag:
                            tag_matches = True
                            break
                    if not tag_matches:
                        all_match = False
                        break
                if all_match:
                    filtered.append(result)
            elif operator == 'NOT':
                # None of the tags should be present (case-insensitive)
                no_match = True
                for search_tag in normalized_search_tags:
                    for result_tag in normalized_result_tags:
                        if search_tag == result_tag or search_tag in result_tag or result_tag in search_tag:
                            no_match = False
                            break
                    if not no_match:
                        break
                if no_match:
                    filtered.append(result)

        return filtered

    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
        """Get collection statistics"""
        try:
            collection = self.create_collection(collection_name)
            total_chunks = collection.count()

            # Count unique documents by getting all metadata and counting unique doc_id or source_name
            # Since ChromaDB doesn't have a direct way to get unique values, we fetch all metadata
            unique_docs = set()
            try:
                # Try to get all items to count unique documents
                # Use get() with a reasonable limit to avoid memory issues
                # If collection is large, we might need to handle pagination
                if total_chunks > 0:
                    # Fetch all chunks (up to 10k limit, adjust if needed)
                    limit = min(10000, total_chunks)
                    results = collection.get(limit=limit)

                    if results and results.get('metadatas'):
                        # Extract unique doc_ids or source_names
                        for metadata in results['metadatas']:
                            # Try doc_id first (most reliable for unique document counting)
                            doc_id = metadata.get('doc_id')
                            if doc_id:
                                unique_docs.add(doc_id)
                            else:
                                # Fallback to source_name if doc_id not available (for older data)
                                source_name = metadata.get('source_name')
                                if source_name:
                                    unique_docs.add(source_name)

                    doc_count = len(unique_docs) if unique_docs else 0
                else:
                    doc_count = 0
            except Exception as e:
                # If we can't get metadata, fall back to using chunk count as estimate
                # This might overcount if documents have multiple chunks, but it's better than 0
                doc_count = total_chunks
                import logging
                logging.getLogger("rag_vector_store").warning(f"Could not count unique documents: {e}, using chunk count as estimate")

            return {
                "document_count": doc_count,
                "chunk_count": total_chunks,
                "collection_name": self._resolve_collection_name(collection_name)
            }
        except Exception as e:
            import logging
            logging.getLogger("rag_vector_store").warning(f"Failed to get collection stats: {e}")
            return {
                "document_count": 0,
                "chunk_count": 0,
                "collection_name": self._resolve_collection_name(collection_name)
            }
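
Usage note: a minimal sketch of the ingest-and-query path through VectorStore; the Chunk dataclass comes from core/utils.py, and the keyword constructor shown here is an assumption based on the fields the store reads (content, metadata, chunk_id, doc_id):

    store = VectorStore(persist_directory="./chroma_data")
    chunk = Chunk(content="Passwords must be rotated every 90 days.",
                  metadata={"tags": ["security", "policy"], "doc_type": "Policy"},
                  chunk_id="doc1_c0", doc_id="doc1")   # assumed signature
    store.add_documents("session_docs", [chunk])
    hits = store.search("session_docs", "password rotation policy", k=3)
    print(store.get_collection_stats("session_docs"))
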
core/ingest.py
ADDED
@@ -0,0 +1,667 @@
import os
import yaml
from typing import List, Dict, Any, Optional
from pathlib import Path
import PyPDF2
from .utils import Chunk, TextProcessor, generate_id
import logging as _logging
_logger = _logging.getLogger("rag_ingest")
import os as _os
_OPENAI_ENABLED = False
try:
    from openai import OpenAI as _OpenAI
    _OPENAI_ENABLED = True if _os.getenv("OPENAI_API_KEY") else False
except Exception:
    _OPENAI_ENABLED = False

class OpenAIMetadataDetector:
    """Use OpenAI to detect language, doc_type, and hierarchy levels for a chunk.
    Falls back to heuristics when OpenAI is not available.
    """
    def __init__(self, hierarchy_manager: 'HierarchyManager'):
        self.hierarchy_manager = hierarchy_manager
        self.client = _OpenAI() if _OPENAI_ENABLED else None
        self.model = _os.getenv("OPENAI_MODEL", "gpt-4o-mini")

    def detect(self, text: str) -> Dict[str, Any]:
        if not self.client:
            return {}
        hierarchies = self.hierarchy_manager.list_hierarchies()
        prompt = (
            "You are a metadata extractor. Given a text chunk, infer: language (en|ja), "
            "document_type (Policy|Manual|FAQ|Report|Note|Guideline), hierarchy_name, level1, level2, level3. "
            "CRITICAL: hierarchy_name MUST be exactly one of the following: "
            f"{hierarchies}. Do not invent other names. "
            "Respond as strict JSON with keys: language, document_type, hierarchy_name, level1, level2, level3. "
            "Be concise; if unsure, pick the closest.\n\nText:\n" + text[:2000]
        )
        try:
            _logger.debug("Calling OpenAI for chunk metadata detection (model=%s)", self.model)
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
            )
            content = resp.choices[0].message.content
            import json as _json
            data = _json.loads(content)
            # Enforce allowed hierarchy set
            if isinstance(data, dict) and data.get("hierarchy_name") not in hierarchies:
                data["hierarchy_name"] = None
            _logger.debug("OpenAI chunk metadata inferred: %s", data)
            return data if isinstance(data, dict) else {}
        except Exception:
            _logger.exception("OpenAI chunk metadata detection failed; using heuristics.")
            return {}

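# Illustrative detect() output (shape follows the prompt above; values invented):
#   {"language": "en", "document_type": "Policy", "hierarchy_name": "corporate",
#    "level1": "HR", "level2": "Compensation", "level3": "Bonuses"}
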
# Try to import pypdf (newer, more robust PDF library)
try:
    from pypdf import PdfReader as PyPdfReader
    PYPDF_AVAILABLE = True
except ImportError:
    PYPDF_AVAILABLE = False

class DocumentLoader:
    """Load documents from various formats"""

    def __init__(self):
        self.text_processor = TextProcessor()

    def load_pdf(self, file_path: str) -> str:
        """Load text from PDF file with fallback readers, preserving paragraphs"""
        # Validate file exists and is readable
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"PDF file not found: {file_path}")

        if not os.path.isfile(file_path):
            raise ValueError(f"Path is not a file: {file_path}")

        # Check file size
        file_size = os.path.getsize(file_path)
        if file_size == 0:
            raise ValueError(f"PDF file is empty: {file_path}")

        # Try pypdf first (more robust)
        if PYPDF_AVAILABLE:
            try:
                with open(file_path, 'rb') as file:
                    reader = PyPdfReader(file)
                    text = ""
                    for page in reader.pages:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                    if text.strip():
                        return self.text_processor.clean_text_preserve_newlines(text)
            except Exception:
                # If pypdf fails, try PyPDF2 as fallback
                pass

        # Fallback to PyPDF2
        try:
            with open(file_path, 'rb') as file:
                # Try to read with strict=False for corrupted PDFs
                try:
                    reader = PyPDF2.PdfReader(file, strict=False)
                except Exception:
                    # If strict=False doesn't work, try normal reader
                    file.seek(0)
                    reader = PyPDF2.PdfReader(file)

                text = ""
                for i, page in enumerate(reader.pages):
                    try:
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                    except Exception:
                        # Skip pages that can't be extracted
                        continue

                if not text.strip():
                    raise ValueError(f"No text could be extracted from PDF: {file_path}")

                return self.text_processor.clean_text_preserve_newlines(text)
        except Exception as e:
            error_msg = str(e)
            if "EOF marker not found" in error_msg or "EOF" in error_msg:
                raise Exception(
                    f"PDF file appears to be corrupted or incomplete: {file_path}. "
                    f"This may be due to an incomplete upload or corrupted file. "
                    f"Please try re-uploading the file or check if the PDF is valid."
                )
            else:
                raise Exception(f"Error loading PDF {file_path}: {error_msg}")

    def load_txt(self, file_path: str) -> str:
        """Load text from TXT file preserving paragraphs"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
                return self.text_processor.clean_text_preserve_newlines(text)
        except Exception as e:
            raise Exception(f"Error loading TXT {file_path}: {str(e)}")

    def load_document(self, file_path: str) -> str:
        """Load document based on file extension"""
        ext = Path(file_path).suffix.lower()
        if ext == '.pdf':
            return self.load_pdf(file_path)
        elif ext == '.txt':
            return self.load_txt(file_path)
        else:
            raise ValueError(f"Unsupported file format: {ext}")

class HierarchyManager:
    """Manage hierarchical metadata definitions"""

    def __init__(self, hierarchies_dir: str = "hierarchies"):
        self.hierarchies_dir = Path(hierarchies_dir)
        self.hierarchies = {}
        self.load_hierarchies()

    def load_hierarchies(self):
        """Load all hierarchy definitions"""
        for yaml_file in self.hierarchies_dir.glob("*.yaml"):
            with open(yaml_file, 'r', encoding='utf-8') as file:
                hierarchy_name = yaml_file.stem
                self.hierarchies[hierarchy_name] = yaml.safe_load(file)

    def get_hierarchy(self, name: str) -> Dict[str, Any]:
        """Get hierarchy definition by name"""
        if name not in self.hierarchies:
            raise ValueError(f"Hierarchy '{name}' not found")
        return self.hierarchies[name]

    def list_hierarchies(self) -> List[str]:
        """List available hierarchies"""
        return list(self.hierarchies.keys())

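# Illustrative hierarchy YAML (shape inferred from the accessors DocumentChunker
# uses below; the names and values are invented):
#   levels:
#     level1:
#       values: [HR, IT]
#     level2:
#       values:
#         HR: [Compensation, Leave]
#         IT: [Security, Networking]
#     level3:
#       values:
#         Security: [Passwords, 2FA]
#   doc_types: [Policy, Manual, FAQ]
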
class DocumentChunker:
    """Chunk documents with hierarchical metadata"""

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_processor = TextProcessor()
        self.hierarchy_manager = HierarchyManager()
        self.ai_detector = OpenAIMetadataDetector(self.hierarchy_manager)

    def chunk_document(self, file_path: str, hierarchy: Optional[str],
                       doc_type: Optional[str], language: Optional[str]) -> List[Chunk]:
        """Chunk document with hierarchical metadata per chunk.
        - Auto-detects hierarchy/doc_type/language when None or 'Auto'.
        - Assigns metadata per chunk to support multi-topic documents.
        """
        import re  # used for explicit-label detection below

        loader = DocumentLoader()
        content = loader.load_document(file_path)

        # Auto-detect language if needed
        if not language or str(language).lower() == 'auto':
            # Prefer OpenAI if available
            ai_guess = self.ai_detector.detect(content)
            _logger.debug("Language auto-detect: ai_guess=%s", ai_guess.get('language') if isinstance(ai_guess, dict) else None)
            language = ai_guess.get('language') if isinstance(ai_guess, dict) and ai_guess.get('language') in ('en', 'ja') else (
                'ja' if any('\u3040' <= ch <= '\u30ff' or '\u4e00' <= ch <= '\u9faf' for ch in content) else 'en'
            )

        # Prepare list of hierarchy names and definitions
        hier_names = self.hierarchy_manager.list_hierarchies()

        # If hierarchy is auto, we'll pick best per-chunk later; else load the chosen one
        fixed_hierarchy_def = None
        if hierarchy and hierarchy.lower() != 'auto':
            fixed_hierarchy_def = self.hierarchy_manager.get_hierarchy(hierarchy)

        # Simple structural chunking: split on double newlines first, then fall back to token windows
        raw_blocks = [b.strip() for b in content.split('\n\n') if b.strip()]
        if not raw_blocks:
            raw_blocks = [content]

        # Further split large blocks into overlapping windows
        processed_blocks: List[str] = []
        for block in raw_blocks:
            words = block.split()
            if len(words) <= self.chunk_size:
                processed_blocks.append(block)
            else:
                step = max(1, self.chunk_size - self.chunk_overlap)
                for i in range(0, len(words), step):
                    processed_blocks.append(' '.join(words[i:i + self.chunk_size]))

        # Phase 1: provisional labels for each block
        provisional: List[Dict[str, Any]] = []
        # Sticky explicit labels propagate until overridden by new explicit labels
        sticky_l1: Optional[str] = None
        sticky_l2: Optional[str] = None
        for block in processed_blocks:
            ai_used = False
            ph_hdef = fixed_hierarchy_def
            ph_hname = hierarchy if hierarchy and hierarchy.lower() != 'auto' else None
            if ph_hdef is None:
                ai_guess = self.ai_detector.detect(block)
                guess_name = ai_guess.get('hierarchy_name') if isinstance(ai_guess, dict) else None
                # 0) Explicit label "Hierarchy: <name>"
                mH = re.search(r"^\s*hierarchy\s*:\s*(.+)$", block, flags=re.IGNORECASE | re.MULTILINE)
                if mH:
                    explicit_h = mH.group(1).strip().lower()
                    for name in hier_names:
                        if name.lower() in explicit_h or explicit_h in name.lower():
                            ph_hdef = self.hierarchy_manager.get_hierarchy(name)
                            ph_hname = name

                # 1) If OpenAI guessed a known hierarchy
                if ph_hdef is None and guess_name in hier_names:
                    ph_hdef = self.hierarchy_manager.get_hierarchy(guess_name)
                    ph_hname = guess_name
                # 2) Weighted keyword scoring across all hierarchies (level1/2/3 + doc_types + filename hints)
                if ph_hdef is None:
                    best_score = -1
                    best_name = None
                    best_def = None
                    block_lower = block.lower()
                    filename_lower = os.path.basename(file_path).lower()
                    for name in hier_names:
                        hdef = self.hierarchy_manager.get_hierarchy(name)
                        score = 0
                        # level1
                        for v in hdef['levels']['level1']['values']:
                            if v.lower() in block_lower:
                                score += 2
                        # level2
                        for l2_list in hdef['levels']['level2']['values'].values():
                            for v in l2_list:
                                if v.lower() in block_lower:
                                    score += 2
                        # level3
                        for l3_list in hdef['levels']['level3']['values'].values():
                            for v in l3_list:
                                if v.lower() in block_lower:
                                    score += 1
                        # doc_types
                        for dt in hdef.get('doc_types', []):
                            if dt.lower() in block_lower:
                                score += 1
                        # filename hint
                        if name.lower() in filename_lower:
                            score += 3
                        if score > best_score:
                            best_score = score
                            best_name = name
                            best_def = hdef
                    ph_hdef = best_def if best_def is not None else self.hierarchy_manager.get_hierarchy(hier_names[0])
                    ph_hname = best_name or hier_names[0]

            ph_dtype = doc_type
            if not doc_type or str(doc_type).lower() == 'auto':
                ai_guess = self.ai_detector.detect(block)
                if isinstance(ai_guess, dict) and ai_guess.get('document_type'):
                    ph_dtype = ai_guess['document_type']
                    ai_used = True
                else:
                    dt_candidates = ph_hdef.get('doc_types', ["Policy", "Manual", "FAQ", "Report", "Note", "Guideline"])
                    block_lower = block.lower()
                    best_dt = dt_candidates[0]
                    best_score = -1
                    for dt in dt_candidates:
                        s = 0
                        if dt.lower() in block_lower:
                            s += 1
                        if dt.lower() == 'faq' and ('faq' in block_lower or 'q:' in block_lower):
                            s += 1
                        if dt.lower() == 'report' and ('report' in block_lower or 'summary' in block_lower):
                            s += 1
                        if s > best_score:
                            best_score = s
                            best_dt = dt
                    ph_dtype = best_dt

            content_lower = block.lower()
            # Detect explicit labels in this block
            exp_l1 = exp_l2 = None
            m1 = re.search(r"^\s*domain\s*:\s*(.+)$", content_lower, flags=re.MULTILINE)
            m2 = re.search(r"^\s*section\s*:\s*(.+)$", content_lower, flags=re.MULTILINE)
            if m1:
                exp_l1 = m1.group(1).strip()
            if m2:
                exp_l2 = m2.group(1).strip()

            # Provisional levels
            ph_l1 = self._classify_level1(content_lower, ph_hdef)
            ph_l2 = self._classify_level2(content_lower, ph_hdef, ph_l1)

            # Override with explicit labels when present
            def _best_match(name: str, candidates: list[str]) -> str:
                name_l = name.lower()
                for c in candidates:
                    cl = c.lower()
                    if cl == name_l or name_l in cl or cl in name_l:
                        return c
                return candidates[0] if candidates else "General"

            if exp_l1:
                ph_l1 = _best_match(exp_l1, ph_hdef['levels']['level1']['values'])
                sticky_l1 = ph_l1
            if exp_l2:
                l2_candidates = ph_hdef['levels']['level2']['values'].get(ph_l1, [])
                ph_l2 = _best_match(exp_l2, l2_candidates)
                sticky_l2 = ph_l2

            # Apply sticky labels when no explicit labels in this block
            if not exp_l1 and sticky_l1:
                ph_l1 = sticky_l1
            if not exp_l2 and sticky_l2 and ph_hdef['levels']['level2']['values'].get(ph_l1):
                ph_l2 = sticky_l2

            provisional.append({
                'text': block,
                'hdef': ph_hdef,
                'hname': ph_hname,
                'dtype': ph_dtype,
                'l1': ph_l1,
                'l2': ph_l2,
                'ai': ai_used
            })

        # Phase 2: merge adjacent blocks with same labels within size limit
        merged_texts: List[str] = []
        merged_meta: List[Dict[str, Any]] = []
        if provisional:
            current_text = provisional[0]['text']
            current_meta = provisional[0]
            for p in provisional[1:]:
                same = (p['hname'] == current_meta['hname'] and p['l1'] == current_meta['l1'] and p['l2'] == current_meta['l2'])
                candidate = current_text + "\n\n" + p['text'] if same else current_text
                if same and self.text_processor.count_tokens(candidate) <= self.text_processor.count_tokens(current_text) + self.chunk_size:
|
| 380 |
+
current_text = candidate
|
| 381 |
+
current_meta['ai'] = current_meta['ai'] or p['ai']
|
| 382 |
+
else:
|
| 383 |
+
merged_texts.append(current_text)
|
| 384 |
+
merged_meta.append(current_meta)
|
| 385 |
+
current_text = p['text']
|
| 386 |
+
current_meta = p
|
| 387 |
+
merged_texts.append(current_text)
|
| 388 |
+
merged_meta.append(current_meta)
|
| 389 |
+
|
| 390 |
+
# Phase 3: finalize chunks
|
| 391 |
+
chunks: List[Chunk] = []
|
| 392 |
+
for text_block, meta in zip(merged_texts, merged_meta):
|
| 393 |
+
final_md = self._generate_metadata(
|
| 394 |
+
file_path=file_path,
|
| 395 |
+
hierarchy_def=meta['hdef'],
|
| 396 |
+
doc_type=meta['dtype'],
|
| 397 |
+
language=language,
|
| 398 |
+
content=text_block
|
| 399 |
+
)
|
| 400 |
+
if meta['hname']:
|
| 401 |
+
final_md['hierarchy'] = meta['hname']
|
| 402 |
+
final_md['ai_detected'] = meta['ai']
|
| 403 |
+
|
| 404 |
+
chunks.append(Chunk(
|
| 405 |
+
doc_id=generate_id(),
|
| 406 |
+
chunk_id=generate_id(),
|
| 407 |
+
content=text_block,
|
| 408 |
+
metadata=final_md
|
| 409 |
+
))
|
| 410 |
+
|
| 411 |
+
return chunks
|
| 412 |
+
|
| 413 |
+
def _generate_metadata(self, file_path: str, hierarchy_def: Dict[str, Any],
|
| 414 |
+
doc_type: str, language: str, content: str) -> Dict[str, Any]:
|
| 415 |
+
"""Generate hierarchical metadata for chunk"""
|
| 416 |
+
# Simple rule-based classification with explicit label override
|
| 417 |
+
content_lower = content.lower()
|
| 418 |
+
|
| 419 |
+
# 1) Try to honor explicit labels like "Domain:", "Section:", "Topic:"
|
| 420 |
+
import re
|
| 421 |
+
explicit_l1 = explicit_l2 = explicit_l3 = None
|
| 422 |
+
m1 = re.search(r"^\s*domain\s*:\s*(.+)$", content_lower, flags=re.MULTILINE)
|
| 423 |
+
m2 = re.search(r"^\s*section\s*:\s*(.+)$", content_lower, flags=re.MULTILINE)
|
| 424 |
+
m3 = re.search(r"^\s*topic\s*:\s*(.+)$", content_lower, flags=re.MULTILINE)
|
| 425 |
+
if m1:
|
| 426 |
+
explicit_l1 = m1.group(1).strip()
|
| 427 |
+
if m2:
|
| 428 |
+
explicit_l2 = m2.group(1).strip()
|
| 429 |
+
if m3:
|
| 430 |
+
explicit_l3 = m3.group(1).strip()
|
| 431 |
+
|
| 432 |
+
def _best_match(name: str, candidates: list[str]) -> str:
|
| 433 |
+
name_l = name.lower()
|
| 434 |
+
# exact contains
|
| 435 |
+
for c in candidates:
|
| 436 |
+
if c.lower() == name_l or name_l in c.lower() or c.lower() in name_l:
|
| 437 |
+
return c
|
| 438 |
+
# fallback: first candidate
|
| 439 |
+
return candidates[0] if candidates else "General"
|
| 440 |
+
|
| 441 |
+
if explicit_l1:
|
| 442 |
+
level1 = _best_match(explicit_l1, hierarchy_def['levels']['level1']['values'])
|
| 443 |
+
else:
|
| 444 |
+
level1 = self._classify_level1(content_lower, hierarchy_def)
|
| 445 |
+
|
| 446 |
+
if explicit_l2:
|
| 447 |
+
level2_candidates = hierarchy_def['levels']['level2']['values'].get(level1, [])
|
| 448 |
+
level2 = _best_match(explicit_l2, level2_candidates)
|
| 449 |
+
else:
|
| 450 |
+
level2 = self._classify_level2(content_lower, hierarchy_def, level1)
|
| 451 |
+
|
| 452 |
+
if explicit_l3:
|
| 453 |
+
level3_candidates = hierarchy_def['levels']['level3']['values'].get(level2, [])
|
| 454 |
+
level3 = _best_match(explicit_l3, level3_candidates)
|
| 455 |
+
else:
|
| 456 |
+
level3 = self._classify_level3(content_lower, hierarchy_def, level1, level2)
|
| 457 |
+
|
| 458 |
+
# Fallback mapping to 'Other' when nothing matches this hierarchy
|
| 459 |
+
def _any_present(values: list[str]) -> bool:
|
| 460 |
+
return any(v.lower() in content_lower for v in values)
|
| 461 |
+
|
| 462 |
+
# If no level1 value appears, set to 'Other'
|
| 463 |
+
if not _any_present(hierarchy_def['levels']['level1']['values']):
|
| 464 |
+
level1 = 'Other'
|
| 465 |
+
# If level2 options for chosen level1 exist but none appear, set to 'Other'
|
| 466 |
+
l2_opts = hierarchy_def['levels']['level2']['values'].get(level1, [])
|
| 467 |
+
if l2_opts and not _any_present(l2_opts):
|
| 468 |
+
level2 = 'Other'
|
| 469 |
+
# If level3 options for chosen level2 exist but none appear, set to 'Other'
|
| 470 |
+
l3_opts = hierarchy_def['levels']['level3']['values'].get(level2, [])
|
| 471 |
+
if l3_opts and not _any_present(l3_opts):
|
| 472 |
+
level3 = 'Other'
|
| 473 |
+
|
| 474 |
+
return {
|
| 475 |
+
"source_name": os.path.basename(file_path),
|
| 476 |
+
"lang": language,
|
| 477 |
+
"level1": level1,
|
| 478 |
+
"level2": level2,
|
| 479 |
+
"level3": level3,
|
| 480 |
+
"doc_type": doc_type,
|
| 481 |
+
"chunk_size": len(content),
|
| 482 |
+
"token_count": self.text_processor.count_tokens(content)
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
def _classify_level1(self, content: str, hierarchy_def: Dict[str, Any]) -> str:
|
| 486 |
+
"""Classify level1 domain"""
|
| 487 |
+
level1_options = hierarchy_def['levels']['level1']['values']
|
| 488 |
+
|
| 489 |
+
# Simple keyword matching (enhance with ML model)
|
| 490 |
+
keyword_scores = {}
|
| 491 |
+
for domain in level1_options:
|
| 492 |
+
score = 0
|
| 493 |
+
# Add domain-specific keyword matching logic
|
| 494 |
+
if domain.lower() in content:
|
| 495 |
+
score += 1
|
| 496 |
+
keyword_scores[domain] = score
|
| 497 |
+
|
| 498 |
+
return max(keyword_scores.items(), key=lambda x: x[1])[0] if keyword_scores else level1_options[0]
|
| 499 |
+
|
| 500 |
+
def _classify_level2(self, content: str, hierarchy_def: Dict[str, Any], level1: str) -> str:
|
| 501 |
+
"""Classify level2 section"""
|
| 502 |
+
level2_options = hierarchy_def['levels']['level2']['values'].get(level1, [])
|
| 503 |
+
if not level2_options:
|
| 504 |
+
return "General"
|
| 505 |
+
|
| 506 |
+
keyword_scores = {}
|
| 507 |
+
for section in level2_options:
|
| 508 |
+
score = 0
|
| 509 |
+
if section.lower() in content:
|
| 510 |
+
score += 1
|
| 511 |
+
keyword_scores[section] = score
|
| 512 |
+
|
| 513 |
+
return max(keyword_scores.items(), key=lambda x: x[1])[0] if keyword_scores else level2_options[0]
|
| 514 |
+
|
| 515 |
+
def _classify_level3(self, content: str, hierarchy_def: Dict[str, Any],
|
| 516 |
+
level1: str, level2: str) -> str:
|
| 517 |
+
"""Classify level3 topic"""
|
| 518 |
+
level3_options = hierarchy_def['levels']['level3']['values'].get(level2, [])
|
| 519 |
+
if not level3_options:
|
| 520 |
+
return "General"
|
| 521 |
+
|
| 522 |
+
keyword_scores = {}
|
| 523 |
+
for topic in level3_options:
|
| 524 |
+
score = 0
|
| 525 |
+
if topic.lower() in content:
|
| 526 |
+
score += 1
|
| 527 |
+
keyword_scores[topic] = score
|
| 528 |
+
|
| 529 |
+
return max(keyword_scores.items(), key=lambda x: x[1])[0] if keyword_scores else level3_options[0]
|
| 530 |
+
|
| 531 |
+
|
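
# Editor's illustration (hedged; not part of the commit): the explicit-label
# override plus sticky propagation above means a document such as
#
#     Domain: HR
#     Section: Benefits
#
#     Dental coverage starts after 90 days.
#
#     Reimbursement claims are processed monthly.
#
# pins the first block to level1="HR" / level2="Benefits" (best-matched against
# the active hierarchy's values), and the second block inherits the same labels
# because no new "Domain:"/"Section:" line appears before it.
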

class FlatTagChunker:
    """Chunk documents and generate flat, non-hierarchical tags for each chunk."""

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200,
                 max_tags: int = 10, min_tag_length: int = 2, max_tag_length: int = 3,
                 use_openai_for_tags: bool = False):
        """
        Initialize flat tag chunker.

        Args:
            chunk_size: Target chunk size in characters
            chunk_overlap: Overlap between chunks in characters
            max_tags: Maximum tags per chunk
            min_tag_length: Minimum words in a tag
            max_tag_length: Maximum words in a tag
            use_openai_for_tags: Whether to use OpenAI for tag generation
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_processor = TextProcessor()

        # Import TagGenerator (lazy import to avoid circular dependencies)
        from .tag_generator import TagGenerator
        self.tag_generator = TagGenerator(
            max_tags=max_tags,
            min_tag_length=min_tag_length,
            max_tag_length=max_tag_length
        )
        self.use_openai_for_tags = use_openai_for_tags

    def chunk_document(self, file_path: str, language: Optional[str] = None,
                       user_tags: Optional[List[str]] = None) -> List[Chunk]:
        """
        Chunk document and generate flat tags for each chunk.

        Args:
            file_path: Path to document file
            language: Language code ('en', 'ja') or None for auto-detect
            user_tags: Optional list of user-provided tags to add to auto-generated tags

        Returns:
            List of Chunk objects with tags in metadata
        """
        loader = DocumentLoader()
        content = loader.load_document(file_path)

        # Auto-detect language if needed
        if not language or str(language).lower() == 'auto':
            # Simple heuristic
            language = 'ja' if any('\u3040' <= ch <= '\u30ff' or '\u4e00' <= ch <= '\u9faf' for ch in content) else 'en'

        # Set language for tag generator
        self.tag_generator.language = language

        # Normalize user tags (lowercase, strip, remove empty)
        normalized_user_tags = []
        if user_tags:
            for tag in user_tags:
                if isinstance(tag, str) and tag.strip():
                    normalized_user_tags.append(tag.strip().lower())

        # Simple structural chunking: split on double newlines first
        raw_blocks = [b.strip() for b in content.split('\n\n') if b.strip()]
        if not raw_blocks:
            raw_blocks = [content]

        # Further split large blocks into chunks
        chunks = []
        for block in raw_blocks:
            if len(block) <= self.chunk_size:
                chunks.append(block)
            else:
                # Split by sentences, then combine into chunks
                sentences = self.text_processor.split_sentences(block)
                current_chunk = ""
                for sentence in sentences:
                    if len(current_chunk) + len(sentence) <= self.chunk_size:
                        current_chunk += sentence + " "
                    else:
                        if current_chunk:
                            chunks.append(current_chunk.strip())
                        current_chunk = sentence + " "
                if current_chunk:
                    chunks.append(current_chunk.strip())

        # Generate tags and create Chunk objects
        result_chunks = []
        source_name = os.path.basename(file_path)
        doc_id = generate_id()  # Generate one doc_id for all chunks from this document

        for i, chunk_text in enumerate(chunks):
            # Generate auto tags for this chunk
            auto_tags = self.tag_generator.generate_tags(
                chunk_text,
                methods=['all'] if not self.use_openai_for_tags else ['yake', 'openai'],
                use_openai=self.use_openai_for_tags
            )

            # Merge user tags with auto-generated tags
            # User tags are prepended (higher priority) and deduplicated
            all_tags = []
            seen_tags = set()

            # Add user tags first (they get priority)
            for tag in normalized_user_tags:
                if tag not in seen_tags:
                    all_tags.append(tag)
                    seen_tags.add(tag)

            # Add auto-generated tags (skip duplicates)
            for tag in auto_tags:
                if tag not in seen_tags:
                    all_tags.append(tag)
                    seen_tags.add(tag)

            # Create chunk metadata
            metadata = {
                'source_name': source_name,
                'chunk_index': i,
                'chunk_size': len(chunk_text),
                'lang': language,
                'tags': all_tags  # Store as list (user tags + auto tags)
            }

            # Create Chunk object
            chunk = Chunk(
                doc_id=doc_id,
                chunk_id=generate_id(),
                content=chunk_text,
                metadata=metadata
            )

            result_chunks.append(chunk)

        _logger.info(f"Generated {len(result_chunks)} chunks with tags from {source_name}")
        return result_chunks
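
# --- Editor's usage sketch (hedged; not part of the commit) ---
# A minimal, self-contained way to exercise FlatTagChunker, assuming
# DocumentLoader can read a plain-text file; the sample content and tags
# below are hypothetical.
if __name__ == "__main__":
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("Domain: HR\n\nEmployees accrue paid leave monthly.\n\n"
                  "Section: Benefits\n\nDental coverage starts after 90 days.")
        sample_path = tmp.name

    chunker = FlatTagChunker(chunk_size=200, chunk_overlap=40, max_tags=8)
    for c in chunker.chunk_document(sample_path, language="en", user_tags=["hr"]):
        # User tags come first in the merged list, followed by auto tags
        print(c.metadata["tags"], "|", c.content[:60])
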
core/report_generator.py
ADDED
@@ -0,0 +1,386 @@
"""
Report Generation Module for RAG Evaluation Results

This module provides comprehensive report generation functionality including:
- HTML/PDF report generation with aggregated statistics
- Representative examples (best/worst performing queries)
- Visualization embedding
- Export functionality (CSV, PNG, HTML, PDF)
"""

import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import json
import os
from datetime import datetime
from collections import defaultdict
import logging

_logger = logging.getLogger("rag_report_generator")

try:
    from jinja2 import Template
    _JINJA2_AVAILABLE = True
except ImportError:
    _JINJA2_AVAILABLE = False
    _logger.warning("Jinja2 not available. HTML report generation will be limited.")


class ReportGenerator:
    """Generate comprehensive evaluation reports"""

    METHOD_NAMES = {
        'base_rag': 'Baseline',
        'tag_filter_rag': '+Tags(Filter)',
        'hybrid_rag': 'Hybrid(Weighted)',
        'hybrid_rerank_rag': 'Hybrid+Rerank',
        'hier_rag': 'Hierarchical RAG'  # For backward compatibility
    }

    def __init__(self):
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def generate_report(self,
                        df: pd.DataFrame,
                        summary: Dict[str, Any],
                        output_dir: str = "reports",
                        report_name: Optional[str] = None,
                        visualizations: Optional[Dict[str, str]] = None,
                        raw_results: Optional[List[Dict[str, Any]]] = None) -> Dict[str, str]:
        """
        Generate comprehensive evaluation report

        Args:
            df: Evaluation results DataFrame
            summary: Summary statistics dictionary
            output_dir: Output directory for reports
            report_name: Base name for report files
            visualizations: Dict mapping chart name to file path
            raw_results: Raw evaluation results for examples

        Returns:
            Dict mapping file type to file path
        """
        os.makedirs(output_dir, exist_ok=True)

        if report_name is None:
            report_name = f"evaluation_report_{self.timestamp}"

        report_paths = {}

        # 1. Export CSV
        csv_path = os.path.join(output_dir, f"{report_name}.csv")
        df.to_csv(csv_path, index=False)
        report_paths['csv'] = csv_path

        # 2. Export JSON summary
        json_path = os.path.join(output_dir, f"{report_name}_summary.json")
        with open(json_path, 'w') as f:
            json.dump(summary, f, indent=2, default=str)
        report_paths['json'] = json_path

        # 3. Export aggregated statistics CSV
        agg_stats_path = os.path.join(output_dir, f"{report_name}_aggregated_stats.csv")
        agg_stats_df = self._generate_aggregated_stats_table(df, summary)
        agg_stats_df.to_csv(agg_stats_path, index=False)
        report_paths['aggregated_csv'] = agg_stats_path

        # 4. Export representative examples
        if raw_results:
            examples_path = os.path.join(output_dir, f"{report_name}_examples.json")
            examples = self._extract_representative_examples(df, raw_results)
            with open(examples_path, 'w') as f:
                json.dump(examples, f, indent=2, default=str)
            report_paths['examples'] = examples_path

        # 5. Export visualizations (if not already exported)
        if visualizations:
            viz_dir = os.path.join(output_dir, "visualizations")
            os.makedirs(viz_dir, exist_ok=True)
            for chart_name, chart_path in visualizations.items():
                if chart_path and os.path.exists(chart_path):
                    # Copy to reports directory if not already there
                    dest_path = os.path.join(viz_dir, f"{report_name}_{chart_name}.png")
                    if not os.path.abspath(chart_path) == os.path.abspath(dest_path):
                        import shutil
                        shutil.copy2(chart_path, dest_path)
                    report_paths[f'viz_{chart_name}'] = dest_path

        # 6. Generate HTML report
        html_path = os.path.join(output_dir, f"{report_name}.html")
        html_content = self._generate_html_report(
            df, summary, report_name, visualizations or {}, raw_results
        )
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        report_paths['html'] = html_path

        return report_paths

    def _generate_aggregated_stats_table(self, df: pd.DataFrame,
                                         summary: Dict[str, Any]) -> pd.DataFrame:
        """Generate aggregated statistics table"""
        rows = []

        # Aggregate by pipeline and k
        for pipeline in df['pipeline'].unique():
            for k in df['k'].unique():
                pipeline_k_df = df[(df['pipeline'] == pipeline) & (df['k'] == k)]

                if pipeline_k_df.empty:
                    continue

                rows.append({
                    'pipeline': self.METHOD_NAMES.get(pipeline, pipeline),
                    'k': k,
                    'mean_precision@k': pipeline_k_df['precision_at_k'].mean(),
                    'std_precision@k': pipeline_k_df['precision_at_k'].std(),
                    'mean_ndcg@k': pipeline_k_df['ndcg_at_k'].mean(),
                    'std_ndcg@k': pipeline_k_df['ndcg_at_k'].std(),
                    'mean_hit@k': pipeline_k_df['hit_at_k'].mean(),
                    'mean_mrr': pipeline_k_df['mrr'].mean(),
                    'mean_semantic_similarity': pipeline_k_df['semantic_similarity'].mean(),
                    'mean_latency': pipeline_k_df['latency'].mean(),
                    'p50_latency': pipeline_k_df['latency'].quantile(0.5),
                    'p90_latency': pipeline_k_df['latency'].quantile(0.9),
                    'query_count': len(pipeline_k_df['query_id'].unique()),
                    'mean_user_satisfaction': pipeline_k_df['user_satisfaction'].mean() if 'user_satisfaction' in pipeline_k_df.columns else None
                })

        return pd.DataFrame(rows)

    def _extract_representative_examples(self,
                                         df: pd.DataFrame,
                                         raw_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Extract representative examples: best and worst performing queries"""
        examples = {
            'best_performing': [],
            'worst_performing': [],
            'most_improved': []
        }

        # Best performing queries (by precision@5)
        best_df = df[df['k'] == 5].nlargest(3, 'precision_at_k')
        for _, row in best_df.iterrows():
            query_id = row['query_id']
            query_result = next((r for r in raw_results if r.get('query_id') == query_id), None)
            if query_result:
                examples['best_performing'].append({
                    'query': query_result['query'],
                    'pipeline': row['pipeline'],
                    'precision_at_k': row['precision_at_k'],
                    'ndcg_at_k': row['ndcg_at_k'],
                    'mrr': row['mrr']
                })

        # Worst performing queries
        worst_df = df[df['k'] == 5].nsmallest(3, 'precision_at_k')
        for _, row in worst_df.iterrows():
            query_id = row['query_id']
            query_result = next((r for r in raw_results if r.get('query_id') == query_id), None)
            if query_result:
                examples['worst_performing'].append({
                    'query': query_result['query'],
                    'pipeline': row['pipeline'],
                    'precision_at_k': row['precision_at_k'],
                    'ndcg_at_k': row['ndcg_at_k'],
                    'mrr': row['mrr']
                })

        # Most improved (hybrid vs baseline)
        if 'hybrid_rag' in df['pipeline'].values and 'base_rag' in df['pipeline'].values:
            baseline_df = df[(df['pipeline'] == 'base_rag') & (df['k'] == 5)].set_index('query_id')
            hybrid_df = df[(df['pipeline'] == 'hybrid_rag') & (df['k'] == 5)].set_index('query_id')

            common_ids = baseline_df.index.intersection(hybrid_df.index)
            if len(common_ids) > 0:
                improvement = (hybrid_df.loc[common_ids, 'precision_at_k'] -
                               baseline_df.loc[common_ids, 'precision_at_k']).nlargest(3)
                for query_id in improvement.index:
                    query_result = next((r for r in raw_results if r.get('query_id') == query_id), None)
                    if query_result:
                        examples['most_improved'].append({
                            'query': query_result['query'],
                            'baseline_precision': baseline_df.loc[query_id, 'precision_at_k'],
                            'hybrid_precision': hybrid_df.loc[query_id, 'precision_at_k'],
                            'improvement': improvement[query_id]
                        })

        return examples

    def _generate_html_report(self,
                              df: pd.DataFrame,
                              summary: Dict[str, Any],
                              report_name: str,
                              visualizations: Dict[str, str],
                              raw_results: Optional[List[Dict[str, Any]]]) -> str:
        """Generate HTML report with all statistics and visualizations"""

        # Generate aggregated stats table
        agg_stats_df = self._generate_aggregated_stats_table(df, summary)
        agg_stats_html = agg_stats_df.to_html(classes='table table-striped', table_id='aggregated_stats', escape=False)

        # Generate summary statistics
        summary_html = self._format_summary_html(summary)

        # Generate representative examples
        examples_html = ""
        if raw_results:
            examples = self._extract_representative_examples(df, raw_results)
            examples_html = self._format_examples_html(examples)

        # Generate visualization HTML
        viz_html = self._format_visualizations_html(visualizations, report_name)

        # Generate insights
        insights_html = self._generate_insights(df, summary)

        # Simple HTML template (without Jinja2 for compatibility)
        html_template = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>RAG Evaluation Report - {report_name}</title>
    <style>
        body {{
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            line-height: 1.6;
            color: #333;
            max-width: 1200px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }}
        h1 {{ color: #2c3e50; border-bottom: 3px solid #3498db; padding-bottom: 10px; }}
        h2 {{ color: #34495e; margin-top: 30px; border-left: 4px solid #3498db; padding-left: 10px; }}
        table {{ width: 100%; border-collapse: collapse; margin: 20px 0; background-color: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
        th {{ background-color: #3498db; color: white; font-weight: bold; }}
        tr:hover {{ background-color: #f5f5f5; }}
        .insights {{ background-color: #fff3cd; border-left: 4px solid #ffc107; padding: 15px; margin: 20px 0; border-radius: 5px; }}
        .examples {{ background-color: #e7f3ff; border-left: 4px solid #2196F3; padding: 15px; margin: 20px 0; border-radius: 5px; }}
        .viz-container {{ text-align: center; margin: 30px 0; background-color: white; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        .viz-container img {{ max-width: 100%; height: auto; border: 1px solid #ddd; border-radius: 3px; }}
    </style>
</head>
<body>
    <h1>RAG Evaluation Report</h1>
    <div class="metadata">
        <p><strong>Report Name:</strong> {report_name}</p>
        <p><strong>Generated:</strong> {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
        <p><strong>Total Queries:</strong> {len(df['query_id'].unique()) if 'query_id' in df.columns else len(df['query'].unique())}</p>
        <p><strong>Pipelines Evaluated:</strong> {len(df['pipeline'].unique())}</p>
    </div>

    <h2>Summary Statistics</h2>
    {summary_html}

    <h2>Aggregated Performance Metrics</h2>
    {agg_stats_html}

    <h2>Visualizations</h2>
    {viz_html}

    <h2>Representative Examples</h2>
    {examples_html}

    <h2>Insights and Recommendations</h2>
    <div class="insights">
        {insights_html}
    </div>
</body>
</html>
"""

        return html_template

    def _format_summary_html(self, summary: Dict[str, Any]) -> str:
        """Format summary statistics as HTML"""
        html = "<table><tr><th>Pipeline</th><th>k</th><th>Mean Precision@k</th><th>Mean nDCG@k</th><th>Mean MRR</th><th>Mean Latency (s)</th><th>P50 Latency (s)</th><th>P90 Latency (s)</th></tr>"

        for pipeline, pipeline_data in summary.items():
            if isinstance(pipeline_data, dict) and 'latency_percentiles' not in pipeline_data:
                for k, metrics in pipeline_data.items():
                    if isinstance(k, int):
                        html += f"<tr>"
                        html += f"<td>{self.METHOD_NAMES.get(pipeline, pipeline)}</td>"
                        html += f"<td>{k}</td>"
                        html += f"<td>{metrics.get('mean_precision_at_k', 0):.3f}</td>"
                        html += f"<td>{metrics.get('mean_ndcg_at_k', 0):.3f}</td>"
                        html += f"<td>{metrics.get('mean_mrr', 0):.3f}</td>"
                        html += f"<td>{metrics.get('mean_latency', 0):.3f}</td>"
                        html += f"<td>{metrics.get('p50_latency', 0):.3f}</td>"
                        html += f"<td>{metrics.get('p90_latency', 0):.3f}</td>"
                        html += f"</tr>"

        html += "</table>"
        return html

    def _format_examples_html(self, examples: Dict[str, Any]) -> str:
        """Format representative examples as HTML"""
        html = ""

        if examples.get('best_performing'):
            html += "<h3>Best Performing Queries</h3>"
            for example in examples['best_performing']:
                html += f"<div class='example-item'><strong>Query:</strong> {example['query']}<br>"
                html += f"<strong>Pipeline:</strong> {self.METHOD_NAMES.get(example['pipeline'], example['pipeline'])}<br>"
                html += f"<strong>Precision@5:</strong> {example['precision_at_k']:.3f}</div>"

        if examples.get('worst_performing'):
            html += "<h3>Worst Performing Queries</h3>"
            for example in examples['worst_performing']:
                html += f"<div class='example-item'><strong>Query:</strong> {example['query']}<br>"
                html += f"<strong>Pipeline:</strong> {self.METHOD_NAMES.get(example['pipeline'], example['pipeline'])}<br>"
                html += f"<strong>Precision@5:</strong> {example['precision_at_k']:.3f}</div>"

        if examples.get('most_improved'):
            html += "<h3>Most Improved Queries (Hybrid vs Baseline)</h3>"
            for example in examples['most_improved']:
                html += f"<div class='example-item'><strong>Query:</strong> {example['query']}<br>"
                html += f"<strong>Improvement:</strong> +{example['improvement']:.3f}</div>"

        if not html:
            return "<p>No representative examples available.</p>"

        return f"<div class='examples'>{html}</div>"

    def _format_visualizations_html(self, visualizations: Dict[str, str], report_name: str) -> str:
        """Format visualization images as HTML"""
        if not visualizations:
            return "<p>No visualizations available.</p>"

        html = ""
        for chart_key, chart_path in visualizations.items():
            if chart_path and os.path.exists(chart_path):
                rel_path = os.path.relpath(chart_path, os.path.dirname(chart_path))
                html += f"<div class='viz-container'><img src='{rel_path}' alt='{chart_key}'></div>"

        return html if html else "<p>No visualizations could be loaded.</p>"

    def _generate_insights(self, df: pd.DataFrame, summary: Dict[str, Any]) -> str:
        """Generate insights and recommendations from the evaluation results"""
        insights = []

        # Find best performing pipeline
        if 'k' in df.columns:
            k5_df = df[df['k'] == 5]
            if len(k5_df) > 0:
                best_pipeline = k5_df.groupby('pipeline')['precision_at_k'].mean().idxmax()
                best_precision = k5_df.groupby('pipeline')['precision_at_k'].mean().max()
                insights.append(f"<li><strong>Best Performing Pipeline:</strong> {self.METHOD_NAMES.get(best_pipeline, best_pipeline)} with average Precision@5 of {best_precision:.3f}</li>")

        # Latency analysis
        if 'latency' in df.columns:
            avg_latency = df['latency'].mean()
            if avg_latency < 0.5:
                insights.append("<li><strong>Latency:</strong> System response time is excellent (<0.5s)</li>")
            elif avg_latency < 1.0:
                insights.append("<li><strong>Latency:</strong> System response time is good (<1.0s)</li>")
            else:
                insights.append(f"<li><strong>Latency:</strong> System response time may need optimization (avg: {avg_latency:.2f}s)</li>")

        return f"<ul>{''.join(insights)}</ul>" if insights else "<p>No specific insights available.</p>"
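
# --- Editor's usage sketch (hedged; not part of the commit) ---
# Driving ReportGenerator end to end with a tiny synthetic DataFrame; the
# column names mirror those the methods above read, and all values are made up.
if __name__ == "__main__":
    demo_df = pd.DataFrame([
        {"pipeline": "base_rag", "k": 5, "query_id": "q1", "precision_at_k": 0.6,
         "ndcg_at_k": 0.70, "hit_at_k": 1.0, "mrr": 0.5, "semantic_similarity": 0.80,
         "latency": 0.12},
        {"pipeline": "hybrid_rag", "k": 5, "query_id": "q1", "precision_at_k": 0.8,
         "ndcg_at_k": 0.85, "hit_at_k": 1.0, "mrr": 1.0, "semantic_similarity": 0.90,
         "latency": 0.21},
    ])
    demo_summary = {
        "base_rag": {5: {"mean_precision_at_k": 0.6, "mean_latency": 0.12}},
        "hybrid_rag": {5: {"mean_precision_at_k": 0.8, "mean_latency": 0.21}},
    }
    paths = ReportGenerator().generate_report(demo_df, demo_summary,
                                              output_dir="reports_demo")
    print(paths)  # e.g. {'csv': ..., 'json': ..., 'aggregated_csv': ..., 'html': ...}
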
core/reranker.py
ADDED
@@ -0,0 +1,183 @@
"""
Reranking Module for RAG Retrieval

This module provides reranking functionality to reorder retrieved documents
based on their relevance to the query using cross-encoder models or semantic similarity.
"""

import logging
from typing import List, Dict, Any, Optional
import numpy as np

_logger = logging.getLogger("rag_reranker")

# Try to import optional dependencies
try:
    from sentence_transformers import CrossEncoder
    CROSSENCODER_AVAILABLE = True
except ImportError:
    CROSSENCODER_AVAILABLE = False
    _logger.debug("sentence-transformers CrossEncoder not available")

try:
    from sentence_transformers import SentenceTransformer
    SENTENCETRANSFORMERS_AVAILABLE = True
except ImportError:
    SENTENCETRANSFORMERS_AVAILABLE = False
    _logger.debug("SentenceTransformers not available")

try:
    from sklearn.metrics.pairwise import cosine_similarity
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    _logger.debug("scikit-learn not available")


class Reranker:
    """
    Reranker for retrieved documents.

    Supports multiple reranking strategies:
    - Cross-encoder models (best quality)
    - Semantic similarity (fallback)
    - Heuristic scoring (last resort)
    """

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize reranker.

        Args:
            model_name: Cross-encoder model name (default: cross-encoder/ms-marco-MiniLM-L-6-v2)
        """
        self.model_name = model_name or "cross-encoder/ms-marco-MiniLM-L-6-v2"
        self._cross_encoder = None
        self._embedding_model = None

    def _initialize_cross_encoder(self):
        """Initialize cross-encoder model."""
        if not CROSSENCODER_AVAILABLE:
            return None

        if self._cross_encoder is None:
            try:
                self._cross_encoder = CrossEncoder(self.model_name)
                _logger.info(f"Initialized CrossEncoder: {self.model_name}")
            except Exception as e:
                _logger.warning(f"Failed to initialize CrossEncoder: {e}")
                return None

        return self._cross_encoder

    def _initialize_embedding_model(self):
        """Initialize embedding model for semantic similarity."""
        if not SENTENCETRANSFORMERS_AVAILABLE:
            return None

        if self._embedding_model is None:
            try:
                # Use a lightweight model for reranking
                self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                _logger.warning(f"Failed to initialize embedding model: {e}")
                return None

        return self._embedding_model

    def rerank(self, query: str, results: List[Dict[str, Any]], top_k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Rerank retrieved documents based on query relevance.

        Args:
            query: Search query
            results: List of retrieved documents with content and metadata
            top_k: Number of top results to return (None for all)

        Returns:
            Reranked list of documents
        """
        if not results:
            return []

        # Try cross-encoder reranking first
        reranked = self._rerank_with_cross_encoder(query, results)
        if reranked:
            if top_k:
                return reranked[:top_k]
            return reranked

        # Fallback to semantic similarity
        reranked = self._rerank_with_similarity(query, results)
        if reranked:
            if top_k:
                return reranked[:top_k]
            return reranked

        # Last resort: return original order
        if top_k:
            return results[:top_k]
        return results

    def _rerank_with_cross_encoder(self, query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank using cross-encoder model."""
        model = self._initialize_cross_encoder()
        if not model:
            return []

        try:
            # Prepare query-document pairs
            pairs = [(query, result['content']) for result in results]

            # Get scores from cross-encoder
            scores = model.predict(pairs)

            # Add rerank scores to results
            for i, result in enumerate(results):
                result['rerank_score'] = float(scores[i])

            # Sort by rerank score (descending)
            reranked = sorted(results, key=lambda x: x.get('rerank_score', 0.0), reverse=True)

            return reranked
        except Exception as e:
            _logger.warning(f"Cross-encoder reranking failed: {e}")
            return []

    def _rerank_with_similarity(self, query: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rerank using semantic similarity."""
        model = self._initialize_embedding_model()
        if not model:
            return []

        try:
            # Encode query
            query_embedding = model.encode(query, convert_to_numpy=True)

            # Encode all documents
            documents = [result['content'] for result in results]
            doc_embeddings = model.encode(documents, convert_to_numpy=True)

            # Calculate cosine similarities
            if SKLEARN_AVAILABLE:
                similarities = cosine_similarity([query_embedding], doc_embeddings)[0]
            else:
                # Manual cosine similarity
                similarities = np.array([
                    np.dot(query_embedding, doc_emb) /
                    (np.linalg.norm(query_embedding) * np.linalg.norm(doc_emb))
                    for doc_emb in doc_embeddings
                ])

            # Add similarity scores to results
            for i, result in enumerate(results):
                result['rerank_score'] = float(similarities[i])

            # Sort by similarity score (descending)
            reranked = sorted(results, key=lambda x: x.get('rerank_score', 0.0), reverse=True)

            return reranked
        except Exception as e:
            _logger.warning(f"Similarity reranking failed: {e}")
            return []
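
# --- Editor's usage sketch (hedged; not part of the commit) ---
# Reranker degrades gracefully: cross-encoder if available, else embedding
# similarity, else the original order. The query and documents are made up.
if __name__ == "__main__":
    demo_results = [
        {"content": "Paid leave accrues at two days per month.", "score": 0.42},
        {"content": "The cafeteria menu rotates weekly.", "score": 0.40},
    ]
    for r in Reranker().rerank("how does paid leave accrue?", demo_results, top_k=2):
        print(round(r.get("rerank_score", 0.0), 3), r["content"])
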
core/retrieval.py
ADDED
@@ -0,0 +1,507 @@
import time
from typing import List, Dict, Any, Optional, Tuple, TYPE_CHECKING
from dataclasses import dataclass
from .index import VectorStore

if TYPE_CHECKING:
    from .reranker import Reranker

@dataclass
class RetrievalResult:
    """Result from retrieval pipeline"""
    content: str
    sources: List[Dict[str, Any]]
    latency: float
    metadata: Dict[str, Any]

class BaseRAG:
    """Standard RAG pipeline without hierarchical filtering"""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
        """Retrieve documents using standard vector similarity"""
        start_time = time.time()

        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            k=k
        )

        latency = time.time() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={"strategy": "base_rag", "k": k}
        )

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieval results into text"""
        formatted = []
        for i, result in enumerate(results, 1):
            formatted.append(f"[{i}] {result['content'][:200]}... (Score: {result['score']:.3f})")
        return "\n\n".join(formatted)

class HierarchicalRAG:
    """Hierarchical RAG pipeline with metadata filtering"""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5,
                 level1: Optional[str] = None,
                 level2: Optional[str] = None,
                 level3: Optional[str] = None,
                 doc_type: Optional[str] = None) -> RetrievalResult:
        """Retrieve documents with hierarchical filtering"""
        start_time = time.time()

        # Build metadata filters
        filters = self._build_filters(level1, level2, level3, doc_type)

        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=filters,
            k=k
        )

        latency = time.time() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={
                "strategy": "hier_rag",
                "k": k,
                "filters": filters
            }
        )

    def _build_filters(self, level1: Optional[str], level2: Optional[str],
                       level3: Optional[str], doc_type: Optional[str]) -> Dict[str, Any]:
        """Build metadata filters for hierarchical search"""
        filters = {}

        if level1:
            filters["level1"] = level1
        if level2:
            filters["level2"] = level2
        if level3:
            filters["level3"] = level3
        if doc_type:
            filters["doc_type"] = doc_type

        return filters if filters else None

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieval results into text"""
        formatted = []
        for i, result in enumerate(results, 1):
            metadata = result['metadata']
            formatted.append(
                f"[{i}] {result['content'][:200]}...\n"
                f"    Domain: {metadata.get('level1', 'N/A')} > "
                f"{metadata.get('level2', 'N/A')} > "
                f"{metadata.get('level3', 'N/A')}\n"
                f"    Score: {result['score']:.3f}"
            )
        return "\n\n".join(formatted)
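
# Editor's note (hedged sketch; not part of the commit): a typical call, assuming
# a VectorStore whose "documents" collection has already been populated:
#
#     rag = HierarchicalRAG(vector_store)
#     result = rag.retrieve("dental coverage waiting period", k=5,
#                           level1="HR", level2="Benefits")
#     result.metadata["filters"]  # -> {"level1": "HR", "level2": "Benefits"}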

class TagFilterRAG:
    """RAG pipeline with flat tag-based filtering"""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5,
                 tags: Optional[List[str]] = None,
                 tag_operator: str = "OR") -> RetrievalResult:
        """
        Retrieve documents using tag filtering.

        Args:
            query: Search query
            k: Number of results
            tags: List of tags to filter by
            tag_operator: Tag filter operator - "AND", "OR", or "NOT"
        """
        start_time = time.time()

        # Build tag filters
        tag_filters = None
        if tags and len(tags) > 0:
            tag_filters = {
                "tags": tags,
                "operator": tag_operator.upper()
            }

        # Search with tag filters
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=None,
            tag_filters=tag_filters,
            k=k
        )

        latency = time.time() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={
                "strategy": "tag_filter_rag",
                "k": k,
                "tags": tags,
                "tag_operator": tag_operator
            }
        )

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieval results into text"""
        formatted = []
        for i, result in enumerate(results, 1):
            metadata = result.get('metadata', {})
            tags = metadata.get('tags', [])
            tags_str = ", ".join(tags[:5]) if isinstance(tags, list) else str(tags)
            if isinstance(tags, list) and len(tags) > 5:
                tags_str += "..."
            formatted.append(
                f"[{i}] {result['content'][:200]}...\n"
                f"    Tags: {tags_str}\n"
                f"    Score: {result['score']:.3f}"
            )
        return "\n\n".join(formatted)
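
# Editor's note (hedged sketch; not part of the commit): with tags=["hr", "leave"]
# and tag_operator="AND", the search call above receives
#
#     tag_filters = {"tags": ["hr", "leave"], "operator": "AND"}
#
# so only chunks carrying both tags remain candidates; "OR" keeps chunks that
# match either tag, and "NOT" presumably excludes chunks that carry them (the
# exact semantics live in VectorStore.search, which is outside this excerpt).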

class HybridRAG:
    """Hybrid RAG pipeline combining vector search and tag search"""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        self.vector_store = vector_store
        self.collection_name = collection_name
        self.base_rag = BaseRAG(vector_store, collection_name)
        self.tag_filter_rag = TagFilterRAG(vector_store, collection_name)

    def retrieve(self, query: str, k: int = 5,
                 tags: Optional[List[str]] = None,
                 tag_operator: str = "OR",
                 vector_weight: float = 0.7,
                 tag_weight: float = 0.3) -> RetrievalResult:
        """
        Retrieve documents using hybrid approach: weighted combination of vector and tag search.

        Args:
            query: Search query text
            k: Number of results to return
            tags: List of tags for tag search (if None, extracts from query)
            tag_operator: Tag filter operator - "AND", "OR", or "NOT"
            vector_weight: Weight for vector similarity score (0.0-1.0)
            tag_weight: Weight for tag matching score (0.0-1.0). Should sum to 1.0 with vector_weight

        Returns:
            RetrievalResult with hybrid-scored documents
        """
        start_time = time.time()

        # Normalize weights
        total_weight = vector_weight + tag_weight
        if total_weight > 0:
            vector_weight = vector_weight / total_weight
            tag_weight = tag_weight / total_weight
        else:
            vector_weight = 0.5
            tag_weight = 0.5

        # Extract tags from query if not provided
        if tags is None:
            tags = self._extract_tags_from_query(query)

        # Fetch more results to have enough for merging
        fetch_k = k * 3

        # Get vector search results
        vector_results = self.base_rag.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=None,
            k=fetch_k
        )

        # Get tag search results (only if tags are provided)
        tag_results = []
        if tags and len(tags) > 0:
            tag_filters = {
                "tags": tags,
                "operator": tag_operator.upper()
            }
            tag_results = self.tag_filter_rag.vector_store.search(
                collection_name=self.collection_name,
                query=query,
                filters=None,
                tag_filters=tag_filters,
                k=fetch_k
            )

        # Merge and combine scores
        hybrid_results = self._combine_results(
            vector_results, tag_results, vector_weight, tag_weight
        )

        # Return top k
        hybrid_results = hybrid_results[:k]

        latency = time.time() - start_time

        return RetrievalResult(
            content=self._format_results(hybrid_results),
            sources=hybrid_results,
            latency=latency,
            metadata={
                "strategy": "hybrid_rag",
                "k": k,
                "vector_weight": vector_weight,
                "tag_weight": tag_weight,
                "tags": tags,
                "tag_operator": tag_operator
            }
        )

    def _extract_tags_from_query(self, query: str) -> List[str]:
        """Extract potential tags from query text (simple keyword extraction)."""
        # Simple approach: split query into words, filter stopwords
        stopwords = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
            'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
            'would', 'should', 'could', 'what', 'how', 'why', 'when', 'where', 'who'
        }
        words = query.lower().split()
        tags = [w.strip('.,!?;:()[]{}"\'') for w in words if len(w) > 3 and w not in stopwords]
        return tags[:5]  # Return top 5 words as tags

    def _combine_results(self, vector_results: List[Dict[str, Any]],
                         tag_results: List[Dict[str, Any]],
                         vector_weight: float, tag_weight: float) -> List[Dict[str, Any]]:
        """Combine vector and tag search results with weighted scoring."""
        # Create combined dict keyed by chunk_id
        combined_dict = {}

        # Add vector results
        for result in vector_results:
            chunk_id = result.get('id', result.get('chunk_id', str(id(result))))
            combined_dict[chunk_id] = {
                **result,
                'vector_score': result.get('score', 0.0),
                'tag_score': 0.0
            }

        # Add tag results and combine scores
        for result in tag_results:
            chunk_id = result.get('id', result.get('chunk_id', str(id(result))))
            tag_score = result.get('score', 0.0)

            if chunk_id in combined_dict:
                # Combine with existing vector score
                combined_dict[chunk_id]['tag_score'] = tag_score
            else:
                # Add new result with only tag score
                combined_dict[chunk_id] = {
                    **result,
                    'vector_score': 0.0,
                    'tag_score': tag_score
                }

        # Normalize scores and compute hybrid scores
        vector_scores = [r.get('vector_score', 0.0) for r in combined_dict.values()]
        tag_scores = [r.get('tag_score', 0.0) for r in combined_dict.values()]

        max_vector_score = max(vector_scores) if vector_scores else 1.0
        min_vector_score = min(vector_scores) if vector_scores else 0.0
|
| 332 |
+
vector_range = max_vector_score - min_vector_score if max_vector_score > min_vector_score else 1.0
|
| 333 |
+
|
| 334 |
+
for chunk_id, result in combined_dict.items():
|
| 335 |
+
vector_score = result.get('vector_score', 0.0)
|
| 336 |
+
tag_score = result.get('tag_score', 0.0)
|
| 337 |
+
|
| 338 |
+
# Normalize vector score to 0-1 range
|
| 339 |
+
if vector_range > 0:
|
| 340 |
+
vector_score = (vector_score - min_vector_score) / vector_range
|
| 341 |
+
else:
|
| 342 |
+
vector_score = 1.0 if vector_score > 0 else 0.0
|
| 343 |
+
|
| 344 |
+
# Ensure tag score is in 0-1 range
|
| 345 |
+
tag_score = max(0.0, min(1.0, tag_score))
|
| 346 |
+
|
| 347 |
+
# Weighted combination
|
| 348 |
+
hybrid_score = (vector_weight * vector_score) + (tag_weight * tag_score)
|
| 349 |
+
|
| 350 |
+
result['score'] = hybrid_score
|
| 351 |
+
result['hybrid_score'] = hybrid_score
|
| 352 |
+
result['vector_score'] = vector_score
|
| 353 |
+
result['tag_score'] = tag_score
|
| 354 |
+
|
| 355 |
+
# Sort by hybrid score (descending)
|
| 356 |
+
sorted_results = sorted(
|
| 357 |
+
combined_dict.values(),
|
| 358 |
+
key=lambda x: x.get('hybrid_score', x.get('score', 0.0)),
|
| 359 |
+
reverse=True
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
return sorted_results
|
| 363 |
+
|
| 364 |
+
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
| 365 |
+
"""Format hybrid retrieval results into text"""
|
| 366 |
+
formatted = []
|
| 367 |
+
for i, result in enumerate(results, 1):
|
| 368 |
+
metadata = result.get('metadata', {})
|
| 369 |
+
tags = metadata.get('tags', [])
|
| 370 |
+
tags_str = ", ".join(tags[:5]) if isinstance(tags, list) else str(tags)
|
| 371 |
+
if isinstance(tags, list) and len(tags) > 5:
|
| 372 |
+
tags_str += "..."
|
| 373 |
+
|
| 374 |
+
vector_score = result.get('vector_score', 0.0)
|
| 375 |
+
tag_score = result.get('tag_score', 0.0)
|
| 376 |
+
hybrid_score = result.get('hybrid_score', result.get('score', 0.0))
|
| 377 |
+
|
| 378 |
+
formatted.append(
|
| 379 |
+
f"[{i}] {result['content'][:200]}...\n"
|
| 380 |
+
f" Tags: {tags_str}\n"
|
| 381 |
+
f" Hybrid Score: {hybrid_score:.3f} "
|
| 382 |
+
f"(Vector: {vector_score:.3f}, Tag: {tag_score:.3f})"
|
| 383 |
+
)
|
| 384 |
+
return "\n\n".join(formatted)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class HybridRerankRAG:
|
| 388 |
+
"""Hybrid RAG pipeline with reranking applied after hybrid retrieval"""
|
| 389 |
+
|
| 390 |
+
def __init__(self, vector_store: VectorStore, collection_name: str = "documents",
|
| 391 |
+
reranker: Optional[Any] = None):
|
| 392 |
+
self.vector_store = vector_store
|
| 393 |
+
self.collection_name = collection_name
|
| 394 |
+
self.hybrid_rag = HybridRAG(vector_store, collection_name)
|
| 395 |
+
|
| 396 |
+
# Lazy import of Reranker
|
| 397 |
+
if reranker is None:
|
| 398 |
+
from .reranker import Reranker
|
| 399 |
+
self.reranker = Reranker()
|
| 400 |
+
else:
|
| 401 |
+
self.reranker = reranker
|
| 402 |
+
|
| 403 |
+
def retrieve(self, query: str, k: int = 5,
|
| 404 |
+
tags: Optional[List[str]] = None,
|
| 405 |
+
tag_operator: str = "OR",
|
| 406 |
+
vector_weight: float = 0.7,
|
| 407 |
+
tag_weight: float = 0.3,
|
| 408 |
+
rerank_top_k: Optional[int] = None) -> RetrievalResult:
|
| 409 |
+
"""
|
| 410 |
+
Retrieve documents using hybrid approach with reranking.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
query: Search query text
|
| 414 |
+
k: Number of results to return
|
| 415 |
+
tags: List of tags for tag search (if None, extracts from query)
|
| 416 |
+
tag_operator: Tag filter operator - "AND", "OR", or "NOT"
|
| 417 |
+
vector_weight: Weight for vector similarity score (0.0-1.0)
|
| 418 |
+
tag_weight: Weight for tag matching score (0.0-1.0)
|
| 419 |
+
rerank_top_k: Number of results to rerank (if None, uses k*2)
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
RetrievalResult with reranked documents
|
| 423 |
+
"""
|
| 424 |
+
start_time = time.time()
|
| 425 |
+
|
| 426 |
+
# Fetch more results for reranking (fetch 2k to rerank, return top k)
|
| 427 |
+
fetch_k = rerank_top_k or (k * 2)
|
| 428 |
+
|
| 429 |
+
# Get hybrid retrieval results
|
| 430 |
+
hybrid_result = self.hybrid_rag.retrieve(
|
| 431 |
+
query=query,
|
| 432 |
+
k=fetch_k,
|
| 433 |
+
tags=tags,
|
| 434 |
+
tag_operator=tag_operator,
|
| 435 |
+
vector_weight=vector_weight,
|
| 436 |
+
tag_weight=tag_weight
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Rerank the results
|
| 440 |
+
reranked_results = self.reranker.rerank(
|
| 441 |
+
query=query,
|
| 442 |
+
results=hybrid_result.sources,
|
| 443 |
+
top_k=k
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
latency = time.time() - start_time
|
| 447 |
+
|
| 448 |
+
return RetrievalResult(
|
| 449 |
+
content=self._format_results(reranked_results),
|
| 450 |
+
sources=reranked_results,
|
| 451 |
+
latency=latency,
|
| 452 |
+
metadata={
|
| 453 |
+
"strategy": "hybrid_rerank_rag",
|
| 454 |
+
"k": k,
|
| 455 |
+
"vector_weight": vector_weight,
|
| 456 |
+
"tag_weight": tag_weight,
|
| 457 |
+
"tags": tags,
|
| 458 |
+
"tag_operator": tag_operator,
|
| 459 |
+
"rerank_top_k": fetch_k
|
| 460 |
+
}
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
def _format_results(self, results: List[Dict[str, Any]]) -> str:
|
| 464 |
+
"""Format reranked retrieval results into text"""
|
| 465 |
+
formatted = []
|
| 466 |
+
for i, result in enumerate(results, 1):
|
| 467 |
+
metadata = result.get('metadata', {})
|
| 468 |
+
tags = metadata.get('tags', [])
|
| 469 |
+
tags_str = ", ".join(tags[:5]) if isinstance(tags, list) else str(tags)
|
| 470 |
+
if isinstance(tags, list) and len(tags) > 5:
|
| 471 |
+
tags_str += "..."
|
| 472 |
+
|
| 473 |
+
rerank_score = result.get('rerank_score', result.get('score', 0.0))
|
| 474 |
+
hybrid_score = result.get('hybrid_score', result.get('score', 0.0))
|
| 475 |
+
vector_score = result.get('vector_score', 0.0)
|
| 476 |
+
tag_score = result.get('tag_score', 0.0)
|
| 477 |
+
|
| 478 |
+
formatted.append(
|
| 479 |
+
f"[{i}] {result['content'][:200]}...\n"
|
| 480 |
+
f" Tags: {tags_str}\n"
|
| 481 |
+
f" Rerank Score: {rerank_score:.3f} "
|
| 482 |
+
f"(Hybrid: {hybrid_score:.3f}, Vector: {vector_score:.3f}, Tag: {tag_score:.3f})"
|
| 483 |
+
)
|
| 484 |
+
return "\n\n".join(formatted)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class RAGManager:
|
| 488 |
+
"""Manager for both RAG pipelines"""
|
| 489 |
+
|
| 490 |
+
def __init__(self, persist_directory: str = "/data/chroma"):
|
| 491 |
+
self.vector_store = VectorStore(persist_directory)
|
| 492 |
+
self.base_rag = BaseRAG(self.vector_store)
|
| 493 |
+
self.hier_rag = HierarchicalRAG(self.vector_store)
|
| 494 |
+
self.tag_filter_rag = TagFilterRAG(self.vector_store)
|
| 495 |
+
self.hybrid_rag = HybridRAG(self.vector_store)
|
| 496 |
+
self.hybrid_rerank_rag = HybridRerankRAG(self.vector_store)
|
| 497 |
+
|
| 498 |
+
def compare_retrieval(self, query: str, k: int = 5,
|
| 499 |
+
level1: Optional[str] = None,
|
| 500 |
+
level2: Optional[str] = None,
|
| 501 |
+
level3: Optional[str] = None,
|
| 502 |
+
doc_type: Optional[str] = None) -> Tuple[RetrievalResult, RetrievalResult]:
|
| 503 |
+
"""Compare Base-RAG vs Hier-RAG"""
|
| 504 |
+
base_result = self.base_rag.retrieve(query, k)
|
| 505 |
+
hier_result = self.hier_rag.retrieve(query, k, level1, level2, level3, doc_type)
|
| 506 |
+
|
| 507 |
+
return base_result, hier_result
|
core/session_manager.py
ADDED
@@ -0,0 +1,353 @@
"""
Session Management Module

This module provides session isolation for multiple concurrent users,
ensuring each user has isolated data and retrieval contexts.
"""

import uuid
import time
import logging
import threading
from typing import Dict, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from collections import defaultdict
import os

_logger = logging.getLogger("rag_session_manager")


@dataclass
class UserSession:
    """Represents a user session with isolated data"""
    session_id: str
    user_id: str = ""
    collection_name: str = ""
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def update_access(self):
        """Update last accessed time"""
        self.last_accessed = datetime.now()

    def is_expired(self, timeout_seconds: int = 3600) -> bool:
        """Check if session has expired"""
        elapsed = (datetime.now() - self.last_accessed).total_seconds()
        return elapsed > timeout_seconds


class SessionManager:
    """
    Manages user sessions with isolated data contexts.

    Each session gets:
    - Unique collection name in ChromaDB
    - Isolated RAG manager instance
    - Session-specific data storage
    """

    def __init__(self,
                 base_persist_dir: str = "./chroma_data",
                 session_timeout: int = 3600,
                 auto_cleanup: bool = True,
                 cleanup_interval: int = 300):
        """
        Initialize session manager.

        Args:
            base_persist_dir: Base directory for ChromaDB persistence
            session_timeout: Session timeout in seconds (default: 1 hour)
            auto_cleanup: Enable automatic session cleanup
            cleanup_interval: Cleanup interval in seconds (default: 5 minutes)
        """
        self.base_persist_dir = base_persist_dir
        self.session_timeout = session_timeout
        self.auto_cleanup = auto_cleanup
        self.cleanup_interval = cleanup_interval

        # Session storage: session_id -> UserSession
        self.sessions: Dict[str, UserSession] = {}

        # User to session mapping: user_id -> [session_ids]
        self.user_sessions: Dict[str, list] = defaultdict(list)

        # Thread lock for thread-safe operations
        self._lock = threading.Lock()

        # Cleanup thread (if enabled)
        self._cleanup_thread = None
        self._stop_cleanup = threading.Event()

        if self.auto_cleanup:
            self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        """Start background thread for session cleanup"""
        def cleanup_loop():
            while not self._stop_cleanup.is_set():
                self._stop_cleanup.wait(self.cleanup_interval)
                if not self._stop_cleanup.is_set():
                    self.cleanup_expired_sessions()

        self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
        self._cleanup_thread.start()
        _logger.info("Session cleanup thread started")

    def create_session(self, user_id: Optional[str] = None) -> UserSession:
        """
        Create a new user session.

        Args:
            user_id: Optional user identifier (e.g., from Gradio)

        Returns:
            UserSession object
        """
        with self._lock:
            # Generate session ID
            session_id = str(uuid.uuid4())

            # Generate collection name
            collection_name = f"session_{session_id[:8]}"

            # Create session
            session = UserSession(
                session_id=session_id,
                user_id=user_id or f"user_{session_id[:8]}",
                collection_name=collection_name
            )

            # Store session
            self.sessions[session_id] = session

            # Map user to session
            self.user_sessions[session.user_id].append(session_id)

            _logger.info(f"Created session {session_id} for user {session.user_id}")

            return session

    def get_session(self, session_id: str) -> Optional[UserSession]:
        """
        Get session by ID.

        Args:
            session_id: Session identifier

        Returns:
            UserSession or None if not found/expired
        """
        with self._lock:
            session = self.sessions.get(session_id)

            if session is None:
                return None

            # Check if expired
            if session.is_expired(self.session_timeout):
                _logger.info(f"Session {session_id} expired")
                self._remove_session(session_id)
                return None

            # Update access time
            session.update_access()

            return session

    def get_or_create_session(self, session_id: Optional[str] = None,
                              user_id: Optional[str] = None) -> UserSession:
        """
        Get existing session or create a new one.
        If session_id is provided but the session doesn't exist in memory,
        checks if the ChromaDB collection exists and restores the session.

        Args:
            session_id: Optional existing session ID
            user_id: Optional user identifier

        Returns:
            UserSession object
        """
        if session_id:
            session = self.get_session(session_id)
            if session:
                return session

            # Session not in memory - check if the ChromaDB collection exists.
            # This handles server restarts where sessions are lost from memory
            # but ChromaDB collections persist on disk.
            collection_name = f"session_{session_id[:8]}"
            if self._collection_exists(collection_name):
                # Restore session from existing ChromaDB collection
                _logger.info(f"Restoring session {session_id} from existing ChromaDB collection {collection_name}")
                session = UserSession(
                    session_id=session_id,
                    user_id=user_id or f"user_{session_id[:8]}",
                    collection_name=collection_name
                )
                with self._lock:
                    self.sessions[session_id] = session
                    self.user_sessions[session.user_id].append(session_id)
                return session

        return self.create_session(user_id)

    def _collection_exists(self, collection_name: str) -> bool:
        """
        Check if a ChromaDB collection exists.
        Since collections are namespaced by embedding provider/dimension,
        we check if any collection starts with the base collection name.

        Args:
            collection_name: Base collection name to check (e.g., "session_abc12345")

        Returns:
            True if collection exists (with any suffix), False otherwise
        """
        try:
            import chromadb
            client = chromadb.PersistentClient(path=self.base_persist_dir)
            collections = client.list_collections()
            collection_names = [col.name for col in collections]
            # Check if any collection name starts with the base collection name,
            # because this app's VectorStore appends suffixes like "__st_384" or "__oai_1536"
            return any(name.startswith(collection_name + "__") for name in collection_names)
        except Exception as e:
            _logger.warning(f"Failed to check if collection {collection_name} exists: {e}")
            return False

    def _remove_session(self, session_id: str):
        """Remove session (internal, assumes lock is held)"""
        session = self.sessions.get(session_id)
        if session:
            # Remove from sessions
            del self.sessions[session_id]

            # Remove from user mapping
            if session.user_id in self.user_sessions:
                try:
                    self.user_sessions[session.user_id].remove(session_id)
                except ValueError:
                    pass

                # Clean up empty user entries
                if not self.user_sessions[session.user_id]:
                    del self.user_sessions[session.user_id]

            _logger.info(f"Removed session {session_id}")

    def remove_session(self, session_id: str):
        """Remove session (public, thread-safe)"""
        with self._lock:
            self._remove_session(session_id)

    def cleanup_expired_sessions(self) -> int:
        """
        Clean up expired sessions.

        Returns:
            Number of sessions cleaned up
        """
        with self._lock:
            expired = [
                session_id for session_id, session in self.sessions.items()
                if session.is_expired(self.session_timeout)
            ]

            for session_id in expired:
                self._remove_session(session_id)

            if expired:
                _logger.info(f"Cleaned up {len(expired)} expired sessions")

            return len(expired)

    def get_user_sessions(self, user_id: str) -> list:
        """
        Get all sessions for a user.

        Args:
            user_id: User identifier

        Returns:
            List of UserSession objects
        """
        with self._lock:
            session_ids = self.user_sessions.get(user_id, [])
            sessions = []

            for session_id in session_ids[:]:  # Copy list
                session = self.sessions.get(session_id)
                if session:
                    if session.is_expired(self.session_timeout):
                        self._remove_session(session_id)
                    else:
                        sessions.append(session)

            return sessions

    def get_session_stats(self) -> Dict[str, Any]:
        """Get statistics about active sessions"""
        with self._lock:
            active_sessions = [
                s for s in self.sessions.values()
                if not s.is_expired(self.session_timeout)
            ]

            return {
                'total_sessions': len(self.sessions),
                'active_sessions': len(active_sessions),
                'unique_users': len(self.user_sessions),
                'expired_sessions': len(self.sessions) - len(active_sessions)
            }

    def clear_all_sessions(self):
        """Clear all sessions (for testing/cleanup)"""
        with self._lock:
            self.sessions.clear()
            self.user_sessions.clear()
            _logger.info("Cleared all sessions")

    def shutdown(self):
        """Shutdown session manager and cleanup thread"""
        if self._cleanup_thread:
            self._stop_cleanup.set()
            self._cleanup_thread.join(timeout=5)
            _logger.info("Session cleanup thread stopped")

    def __del__(self):
        """Cleanup on destruction"""
        self.shutdown()


# Global session manager instance
_global_session_manager: Optional[SessionManager] = None


def get_session_manager(base_persist_dir: str = "./chroma_data",
                        session_timeout: int = 3600,
                        auto_cleanup: bool = True) -> SessionManager:
    """
    Get or create global session manager instance.

    Args:
        base_persist_dir: Base directory for ChromaDB persistence
        session_timeout: Session timeout in seconds
        auto_cleanup: Enable automatic cleanup

    Returns:
        SessionManager instance
    """
    global _global_session_manager

    if _global_session_manager is None:
        _global_session_manager = SessionManager(
            base_persist_dir=base_persist_dir,
            session_timeout=session_timeout,
            auto_cleanup=auto_cleanup
        )

    return _global_session_manager
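
A short sketch of the session lifecycle this module implements; the user id and timeout values are illustrative.

# Usage sketch for SessionManager.
from core.session_manager import get_session_manager

sm = get_session_manager(base_persist_dir="./chroma_data", session_timeout=1800)

session = sm.get_or_create_session(user_id="alice")  # "alice" is illustrative
same = sm.get_or_create_session(session_id=session.session_id)
assert same.collection_name == session.collection_name  # one collection per session

print(sm.get_session_stats())   # e.g. {'total_sessions': 1, 'active_sessions': 1, ...}
sm.cleanup_expired_sessions()   # normally run by the background thread
sm.shutdown()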
core/session_rag.py
ADDED
@@ -0,0 +1,100 @@
"""
Session-Aware RAG Manager

Helper module to create session-specific RAG managers and handle
session isolation for multiple concurrent users.
"""

import logging
from typing import Optional, Dict, Any
from .retrieval import RAGManager, BaseRAG, TagFilterRAG, HybridRAG, HybridRerankRAG
from .session_manager import SessionManager, UserSession

_logger = logging.getLogger("rag_session_rag")


class SessionAwareRAGManager:
    """
    Session-aware wrapper for RAGManager that provides per-session isolation.

    This class creates session-specific RAG pipelines by using different
    collection names per session, ensuring data isolation.
    """

    def __init__(self, base_rag_manager: RAGManager, session_manager: SessionManager):
        """
        Initialize session-aware RAG manager.

        Args:
            base_rag_manager: Base RAGManager instance (shared vector store)
            session_manager: SessionManager instance for session handling
        """
        self.base_rag_manager = base_rag_manager
        self.session_manager = session_manager
        # Cache of session-specific RAG managers
        self._session_managers: Dict[str, RAGManager] = {}

    def get_session_manager(self, session: Optional[UserSession]) -> RAGManager:
        """
        Get RAG manager for a session.

        Args:
            session: UserSession object or None for default

        Returns:
            RAGManager instance for the session
        """
        if session is None:
            # Use default collection
            return self.base_rag_manager

        # Get or create session-specific manager
        collection_name = session.collection_name
        if collection_name not in self._session_managers:
            # Create new RAGManager with session-specific collection
            session_rag = RAGManager(
                persist_directory=self.base_rag_manager.vector_store.persist_directory
            )
            # Update collection names in all pipelines, including the nested ones
            session_rag.base_rag.collection_name = collection_name
            session_rag.tag_filter_rag.collection_name = collection_name
            session_rag.hybrid_rag.collection_name = collection_name
            session_rag.hybrid_rerank_rag.collection_name = collection_name
            session_rag.hybrid_rag.base_rag.collection_name = collection_name
            session_rag.hybrid_rag.tag_filter_rag.collection_name = collection_name
            session_rag.hybrid_rerank_rag.hybrid_rag.collection_name = collection_name
            session_rag.hybrid_rerank_rag.hybrid_rag.base_rag.collection_name = collection_name
            session_rag.hybrid_rerank_rag.hybrid_rag.tag_filter_rag.collection_name = collection_name

            self._session_managers[collection_name] = session_rag
            _logger.debug(f"Created session RAG manager for collection: {collection_name}")

        return self._session_managers[collection_name]

    def get_rag(self, session_id: str) -> RAGManager:
        """
        Get RAG manager for a session ID.

        Args:
            session_id: Session identifier

        Returns:
            RAGManager instance for the session
        """
        session = self.session_manager.get_session(session_id)
        if session is None:
            # If session doesn't exist, use default collection
            _logger.warning(f"Session {session_id} not found, using default collection")
            return self.base_rag_manager
        return self.get_session_manager(session)

    def cleanup_session(self, session_id: str):
        """Clean up session-specific RAG manager"""
        # Find collection name for session
        session = self.session_manager.get_session(session_id)
        if session:
            collection_name = session.collection_name
            if collection_name in self._session_managers:
                del self._session_managers[collection_name]
                _logger.debug(f"Cleaned up session RAG manager for collection: {collection_name}")
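
A wiring sketch showing how this wrapper sits between SessionManager and RAGManager; paths, ids, and the query are illustrative.

# Wiring sketch: one shared vector store, per-session collections.
from core.retrieval import RAGManager
from core.session_manager import get_session_manager
from core.session_rag import SessionAwareRAGManager

base = RAGManager(persist_directory="./chroma_data")
session_rag = SessionAwareRAGManager(base, get_session_manager("./chroma_data"))

session = session_rag.session_manager.get_or_create_session(user_id="alice")
rag = session_rag.get_rag(session.session_id)  # RAGManager bound to session_<id[:8]>
result = rag.hybrid_rag.retrieve("example query", k=3)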
core/tag_generator.py
ADDED
@@ -0,0 +1,464 @@
"""
Tag Generation Module for Flat, Non-Hierarchical Tagging

This module provides automatic tag generation using multiple methods:
1. Keyphrase extraction (YAKE, KeyBERT)
2. Noun phrase analysis (spaCy, Janome for Japanese)
3. OpenAI-based generation (optional)

Supports both English and Japanese text.
"""

import re
import logging
from typing import List, Dict, Any, Optional, Set
from collections import Counter
import os

_logger = logging.getLogger("rag_tag_generator")

# Try to import optional dependencies
try:
    import yake
    YAKE_AVAILABLE = True
except ImportError:
    YAKE_AVAILABLE = False
    _logger.debug("YAKE not available, will use fallback methods")

try:
    from keybert import KeyBERT
    KEYBERT_AVAILABLE = True
except ImportError:
    KEYBERT_AVAILABLE = False
    _logger.debug("KeyBERT not available, will use fallback methods")

try:
    import spacy
    SPACY_AVAILABLE = True
except ImportError:
    SPACY_AVAILABLE = False
    _logger.debug("spaCy not available for noun phrase extraction")

try:
    from janome.tokenizer import Tokenizer
    JANOME_AVAILABLE = True
except ImportError:
    JANOME_AVAILABLE = False
    _logger.debug("Janome not available for Japanese tokenization")

# OpenAI for tag generation (optional, fallback)
_OPENAI_ENABLED = False
try:
    from openai import OpenAI as _OpenAI
    _OPENAI_ENABLED = True if os.getenv("OPENAI_API_KEY") else False
except Exception:
    _OPENAI_ENABLED = False


class TagGenerator:
    """
    Automatic flat tag generation from text documents.

    Supports multiple extraction methods:
    - Keyphrase extraction (YAKE, KeyBERT)
    - Noun phrase analysis (spaCy, Janome)
    - OpenAI-based generation (optional)
    """

    def __init__(self,
                 max_tags: int = 10,
                 min_tag_length: int = 2,
                 max_tag_length: int = 3,
                 language: Optional[str] = None):
        """
        Initialize tag generator.

        Args:
            max_tags: Maximum number of tags to generate per chunk
            min_tag_length: Minimum words in a tag phrase
            max_tag_length: Maximum words in a tag phrase
            language: Language code ('en', 'ja') or None for auto-detect
        """
        self.max_tags = max_tags
        self.min_tag_length = min_tag_length
        self.max_tag_length = max_tag_length
        self.language = language

        # Initialize models lazily
        self._yake_extractor = None
        self._keybert_model = None
        self._spacy_model = None
        self._janome_tokenizer = None
        self._openai_client = None

        # Common stopwords for filtering
        self.stopwords_en = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be',
            'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
            'would', 'should', 'could', 'may', 'might', 'must', 'can', 'this',
            'that', 'these', 'those', 'it', 'its', 'they', 'them', 'their'
        }

        # Japanese stopwords (basic set)
        self.stopwords_ja = {
            'の', 'に', 'は', 'を', 'た', 'が', 'で', 'て', 'と', 'し', 'れ',
            'さ', 'ある', 'いる', 'も', 'する', 'から', 'な', 'こと', 'として',
            'い', 'や', 'れる', 'など', 'なっ', 'たら', 'なり', 'られる'
        }

    def _detect_language(self, text: str) -> str:
        """Detect language from text."""
        if self.language:
            return self.language

        # Simple heuristic: check for Japanese characters
        if any('\u3040' <= ch <= '\u30ff' or '\u4e00' <= ch <= '\u9faf' for ch in text):
            return 'ja'
        return 'en'

    def _initialize_yake(self, language: str):
        """Initialize YAKE keyphrase extractor."""
        if not YAKE_AVAILABLE:
            return None

        if self._yake_extractor is None:
            try:
                # YAKE language mapping
                yake_lang = 'ja' if language == 'ja' else 'en'
                self._yake_extractor = yake.KeywordExtractor(
                    lan=yake_lang,
                    n=self.max_tag_length,
                    dedupLim=0.9,
                    top=self.max_tags * 2,  # Extract more, then filter
                    features=None
                )
            except Exception as e:
                _logger.warning(f"Failed to initialize YAKE: {e}")
                return None

        return self._yake_extractor

    def _initialize_keybert(self):
        """Initialize KeyBERT model."""
        if not KEYBERT_AVAILABLE:
            return None

        if self._keybert_model is None:
            try:
                self._keybert_model = KeyBERT()
            except Exception as e:
                _logger.warning(f"Failed to initialize KeyBERT: {e}")
                return None

        return self._keybert_model

    def _initialize_spacy(self, language: str):
        """Initialize spaCy model for noun phrase extraction."""
        if not SPACY_AVAILABLE or language != 'en':
            return None

        if self._spacy_model is None:
            try:
                # Try to load English model
                self._spacy_model = spacy.load("en_core_web_sm")
            except OSError:
                _logger.warning("spaCy English model not found. Install with: python -m spacy download en_core_web_sm")
                return None
            except Exception as e:
                _logger.warning(f"Failed to initialize spaCy: {e}")
                return None

        return self._spacy_model

    def _initialize_janome(self):
        """Initialize Janome tokenizer for Japanese."""
        if not JANOME_AVAILABLE:
            return None

        if self._janome_tokenizer is None:
            try:
                self._janome_tokenizer = Tokenizer()
            except Exception as e:
                _logger.warning(f"Failed to initialize Janome: {e}")
                return None

        return self._janome_tokenizer

    def _initialize_openai(self):
        """Initialize OpenAI client for tag generation."""
        if not _OPENAI_ENABLED:
            return None

        if self._openai_client is None:
            try:
                self._openai_client = _OpenAI()
            except Exception as e:
                _logger.warning(f"Failed to initialize OpenAI: {e}")
                return None

        return self._openai_client

    def _extract_with_yake(self, text: str, language: str) -> List[str]:
        """Extract tags using YAKE."""
        extractor = self._initialize_yake(language)
        if not extractor:
            return []

        try:
            # yake returns (keyphrase, score) tuples, lowest score first
            keywords = extractor.extract_keywords(text)
            tags = [kw[0] for kw in keywords[:self.max_tags * 2]]  # Extract more than needed
            return self._filter_tags(tags, language)
        except Exception as e:
            _logger.warning(f"YAKE extraction failed: {e}")
            return []

    def _extract_with_keybert(self, text: str, language: str) -> List[str]:
        """Extract tags using KeyBERT (English only)."""
        if language != 'en':
            return []

        model = self._initialize_keybert()
        if not model:
            return []

        try:
            keywords = model.extract_keywords(
                text,
                keyphrase_ngram_range=(self.min_tag_length, self.max_tag_length),
                top_n=self.max_tags * 2
            )
            tags = [kw[0] for kw in keywords]
            return self._filter_tags(tags, language)
        except Exception as e:
            _logger.warning(f"KeyBERT extraction failed: {e}")
            return []

    def _extract_noun_phrases_spacy(self, text: str) -> List[str]:
        """Extract noun phrases using spaCy (English only)."""
        model = self._initialize_spacy('en')
        if not model:
            return []

        try:
            doc = model(text)
            noun_phrases = []
            for chunk in doc.noun_chunks:
                if self.min_tag_length <= len(chunk.text.split()) <= self.max_tag_length:
                    phrase = chunk.text.lower().strip()
                    if phrase and phrase not in self.stopwords_en:
                        noun_phrases.append(phrase)
            return self._filter_tags(noun_phrases[:self.max_tags * 2], 'en')
        except Exception as e:
            _logger.warning(f"spaCy noun phrase extraction failed: {e}")
            return []

    def _extract_noun_phrases_janome(self, text: str) -> List[str]:
        """Extract noun phrases using Janome (Japanese only)."""
        tokenizer = self._initialize_janome()
        if not tokenizer:
            return []

        try:
            tokens = tokenizer.tokenize(text)
            noun_phrases = []
            current_phrase = []

            for token in tokens:
                # Extract nouns (名詞) and compound nouns
                if token.part_of_speech.split(',')[0] == '名詞':
                    current_phrase.append(token.surface)
                else:
                    if len(current_phrase) >= self.min_tag_length:
                        phrase = ''.join(current_phrase)
                        # Filter stopwords
                        if phrase and not all(c in self.stopwords_ja for c in phrase):
                            noun_phrases.append(phrase)
                    current_phrase = []

            # Handle last phrase
            if len(current_phrase) >= self.min_tag_length:
                phrase = ''.join(current_phrase)
                if phrase and not all(c in self.stopwords_ja for c in phrase):
                    noun_phrases.append(phrase)

            return self._filter_tags(noun_phrases[:self.max_tags], 'ja')
        except Exception as e:
            _logger.warning(f"Janome noun phrase extraction failed: {e}")
            return []

    def _extract_with_openai(self, text: str, language: str) -> List[str]:
        """Extract tags using OpenAI (optional fallback)."""
        client = self._initialize_openai()
        if not client:
            return []

        try:
            lang_name = 'Japanese' if language == 'ja' else 'English'
            prompt = (
                f"Extract {self.max_tags} flat, non-hierarchical tags (keywords/phrases) from the following {lang_name} text. "
                f"Tags should be {self.min_tag_length}-{self.max_tag_length} words each. "
                "Return only a JSON array of tag strings, no explanation.\n\n"
                f"Text:\n{text[:2000]}"
            )

            response = client.chat.completions.create(
                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
            )

            content = response.choices[0].message.content
            import json
            # Try to extract JSON array
            if content.strip().startswith('['):
                tags = json.loads(content)
            else:
                # Try to find JSON in the response
                json_match = re.search(r'\[.*\]', content, re.DOTALL)
                if json_match:
                    tags = json.loads(json_match.group())
                else:
                    # Fallback: split by lines or commas
                    tags = [t.strip().strip('"\'') for t in content.split('\n') if t.strip()]

            return self._filter_tags(tags[:self.max_tags], language)
        except Exception as e:
            _logger.warning(f"OpenAI tag extraction failed: {e}")
            return []

    def _filter_tags(self, tags: List[str], language: str) -> List[str]:
        """Filter and normalize tags."""
        stopwords = self.stopwords_ja if language == 'ja' else self.stopwords_en
        filtered = []
        seen = set()

        for tag in tags:
            if not tag or not isinstance(tag, str):
                continue

            # Normalize: lowercase, strip
            tag = tag.lower().strip()

            # Check length
            if language == 'ja':
                # For Japanese, count characters
                if len(tag) < self.min_tag_length:
                    continue
            else:
                # For English, count words
                words = tag.split()
                if len(words) < self.min_tag_length or len(words) > self.max_tag_length:
                    continue

            # Filter stopwords
            if language == 'ja':
                if all(c in stopwords for c in tag):
                    continue
            else:
                words = tag.split()
                if all(w in stopwords for w in words):
                    continue

            # Deduplicate
            if tag not in seen:
                seen.add(tag)
                filtered.append(tag)

        return filtered[:self.max_tags]

    def generate_tags(self, text: str, methods: Optional[List[str]] = None,
                      use_openai: bool = False) -> List[str]:
        """
        Generate tags from text using specified methods.

        Args:
            text: Input text
            methods: List of methods ('yake', 'keybert', 'noun_phrases', 'openai', 'all')
            use_openai: Whether to use OpenAI (requires API key)

        Returns:
            List of generated tags
        """
        if not text or not text.strip():
            return []

        # Detect language
        language = self._detect_language(text)

        # Determine methods to use
        if methods is None or 'all' in methods:
            methods = ['yake']
            if KEYBERT_AVAILABLE and language == 'en':
                methods.append('keybert')
            if SPACY_AVAILABLE and language == 'en':
                methods.append('noun_phrases')
            if JANOME_AVAILABLE and language == 'ja':
                methods.append('noun_phrases')
            if use_openai and _OPENAI_ENABLED:
                methods.append('openai')

        # Collect tags from all methods
        all_tags = []

        for method in methods:
            try:
                if method == 'yake':
                    tags = self._extract_with_yake(text, language)
                    if tags:
                        all_tags.extend(tags)
                        _logger.debug(f"YAKE extracted {len(tags)} tags")
                elif method == 'keybert' and language == 'en':
                    tags = self._extract_with_keybert(text, language)
                    if tags:
                        all_tags.extend(tags)
                        _logger.debug(f"KeyBERT extracted {len(tags)} tags")
                elif method == 'noun_phrases':
                    if language == 'en' and SPACY_AVAILABLE:
                        tags = self._extract_noun_phrases_spacy(text)
                        if tags:
                            all_tags.extend(tags)
                            _logger.debug(f"spaCy noun phrases: {len(tags)} tags")
                    elif language == 'ja' and JANOME_AVAILABLE:
                        tags = self._extract_noun_phrases_janome(text)
                        if tags:
                            all_tags.extend(tags)
                            _logger.debug(f"Janome noun phrases: {len(tags)} tags")
                elif method == 'openai' and use_openai:
                    tags = self._extract_with_openai(text, language)
                    if tags:
                        all_tags.extend(tags)
                        _logger.debug(f"OpenAI extracted {len(tags)} tags")
            except Exception as e:
                _logger.warning(f"Tag extraction method {method} failed: {e}")
                continue

        # Deduplicate and rank by frequency
        if all_tags:
            tag_counts = Counter(all_tags)
            # Sort by frequency, then alphabetically
            sorted_tags = sorted(tag_counts.items(), key=lambda x: (-x[1], x[0]))
            return [tag for tag, _ in sorted_tags[:self.max_tags]]

        # Fallback: simple keyword extraction
        return self._fallback_extraction(text, language)

    def _fallback_extraction(self, text: str, language: str) -> List[str]:
        """Fallback tag extraction when no methods are available."""
        stopwords = self.stopwords_ja if language == 'ja' else self.stopwords_en

        if language == 'ja':
            # For Japanese, extract runs of kana/kanji as candidate words
            # (a per-character split filtered with len(w) > 1 would always be empty)
            words = re.findall(r'[\u3040-\u30ff\u4e00-\u9faf]+', text)
            tags = [w for w in words if w not in stopwords and len(w) > 1]
        else:
            # For English, extract words
            words = re.findall(r'\b\w+\b', text.lower())
            tags = [w for w in words if w not in stopwords and len(w) >= self.min_tag_length]

        # Return top N by frequency
        tag_counts = Counter(tags)
        sorted_tags = sorted(tag_counts.items(), key=lambda x: (-x[1], x[0]))
        return [tag for tag, _ in sorted_tags[:self.max_tags]]
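
A usage sketch for the generator. The sample sentences are illustrative, and which extractors actually run depends on the optional dependencies installed; the frequency-based fallback covers the case where none are.

# Usage sketch for TagGenerator.
from core.tag_generator import TagGenerator

gen = TagGenerator(max_tags=5, min_tag_length=1, max_tag_length=3)

# English: YAKE (plus KeyBERT/spaCy when installed), ranked by cross-method frequency
print(gen.generate_tags("Vector databases store embeddings for semantic search."))

# Japanese is auto-detected and routed to Janome noun-phrase extraction
print(gen.generate_tags("ベクトルデータベースは埋め込みを保存します。"))

# Opt into the OpenAI method only when OPENAI_API_KEY is set
print(gen.generate_tags("Hybrid retrieval combines tags and vectors.", use_openai=True))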
core/utils.py
ADDED
@@ -0,0 +1,148 @@
| 1 |
+
import uuid
|
| 2 |
+
import re
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, Any, List, Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
# Try to import tiktoken, but make it optional
try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except Exception as e:
    logger.warning(f"tiktoken not available: {e}. Token counting will use fallback method.")
    TIKTOKEN_AVAILABLE = False
    tiktoken = None

# Global cache for tiktoken encoding (to avoid repeated download attempts)
_TIKTOKEN_ENCODING = None
_TIKTOKEN_LOAD_ATTEMPTED = False


@dataclass
class Chunk:
    """Data class for document chunks"""
    doc_id: str
    chunk_id: str
    content: str
    metadata: Dict[str, Any]
    embeddings: Optional[List[float]] = None


class TextProcessor:
    """Text processing utilities"""

    def __init__(self):
        global _TIKTOKEN_ENCODING, _TIKTOKEN_LOAD_ATTEMPTED

        # Use cached encoding if available
        if _TIKTOKEN_ENCODING is not None:
            self.encoding = _TIKTOKEN_ENCODING
            return

        # Try to load the encoding only once and cache the result
        self.encoding = None
        if TIKTOKEN_AVAILABLE and not _TIKTOKEN_LOAD_ATTEMPTED:
            _TIKTOKEN_LOAD_ATTEMPTED = True
            try:
                encoding = tiktoken.get_encoding("cl100k_base")
                _TIKTOKEN_ENCODING = encoding
                self.encoding = encoding
            except Exception:
                # Covers network errors, connection refused, and similar issues.
                # This is expected when offline: tiktoken downloads the encoding file on first use.
                # Log at info level only (not warning), since this is expected offline behavior.
                logger.info("Tiktoken encoding not available (offline mode). Using fallback token counting (characters/4).")
                _TIKTOKEN_ENCODING = None  # Cache the failure
                self.encoding = None
        elif _TIKTOKEN_LOAD_ATTEMPTED:
            # Already attempted; use the cached result (None if the load failed)
            self.encoding = _TIKTOKEN_ENCODING

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using tiktoken if available, otherwise use character-based estimate"""
        if self.encoding:
            try:
                return len(self.encoding.encode(text))
            except Exception as e:
                logger.debug(f"tiktoken encoding failed: {e}, using fallback")

        # Fallback: approximate token count (rough estimate: 1 token ≈ 4 characters).
        # This is a simple approximation, not as accurate as tiktoken.
        return len(text) // 4

    def mask_pii(self, text: str) -> str:
        """Mask personally identifiable information"""
        # Email addresses
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '[EMAIL]', text)
        # Phone numbers
        text = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', text)
        # Credit card numbers
        text = re.sub(r'\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b', '[CREDIT_CARD]', text)
        # SSN
        text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[SSN]', text)
        return text

    def clean_text(self, text: str) -> str:
        """Clean and normalize text"""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove special characters but keep basic punctuation
        text = re.sub(r'[^\w\s.,!?;:()\-]', '', text)
        return text.strip()

    def clean_text_preserve_newlines(self, text: str) -> str:
        """Normalize text but preserve paragraph breaks for chunking.
        - Normalize Windows newlines to \n
        - Trim spaces on each line
        - Collapse 3+ newlines -> 2 newlines (keep blank lines as separators)
        - Collapse multiple spaces within lines
        - Keep basic punctuation
        """
        # Normalize line endings
        text = text.replace('\r\n', '\n').replace('\r', '\n')
        # Trim spaces on each line
        text = '\n'.join(line.strip() for line in text.split('\n'))
        # Collapse 3+ newlines to 2 newlines
        text = re.sub(r'\n{3,}', '\n\n', text)
        # Collapse multiple spaces within lines
        text = re.sub(r'[ \t]+', ' ', text)
        # Remove disallowed characters but keep punctuation and newlines
        text = re.sub(r'[^\w\s\n.,!?;:()\-]', '', text)
        return text.strip()

    def split_sentences(self, text: str) -> List[str]:
        """Split text into sentences using simple regex-based approach.

        Args:
            text: Input text to split

        Returns:
            List of sentences
        """
        # Simple sentence splitting: split on sentence-ending punctuation.
        # Pattern: period, exclamation, or question mark followed by space or end of string.
        sentences = re.split(r'([.!?]+(?:\s+|$))', text)

        # Combine punctuation with preceding sentence
        result = []
        for i in range(0, len(sentences) - 1, 2):
            sentence = sentences[i]
            if i + 1 < len(sentences):
                sentence += sentences[i + 1]
            sentence = sentence.strip()
            if sentence:
                result.append(sentence)

        # Handle last sentence if odd number of splits
        if len(sentences) % 2 == 1 and sentences[-1].strip():
            result.append(sentences[-1].strip())

        # If no sentences found, return the whole text as one sentence
        if not result:
            return [text.strip()] if text.strip() else []

        return result


def generate_id() -> str:
    """Generate unique ID"""
    return str(uuid.uuid4())
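A minimal usage sketch for the utilities above (assuming the module-level `logger`, `re`, `uuid`, dataclass, and typing imports defined earlier in core/utils.py); it shows the tiktoken fallback and the PII masking in action:

```python
from core.utils import TextProcessor, generate_id

tp = TextProcessor()

# Exact token count via tiktoken when available, otherwise roughly len(text) // 4.
print(tp.count_tokens("Retrieval-augmented generation with tags."))

# PII masking replaces emails, phone numbers, credit card numbers, and SSNs.
print(tp.mask_pii("Contact jane@example.com or 555-123-4567."))
# -> "Contact [EMAIL] or [PHONE]."

print(generate_id())  # random UUID4 string
```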
core/visualization.py
ADDED
@@ -0,0 +1,291 @@
"""
Visualization Module for RAG Evaluation Results

This module provides functions to create various charts and visualizations
for RAG evaluation results including bar charts, line charts, scatter plots,
box plots, stacked bars, and Pareto charts.
"""

import pandas as pd
import numpy as np
import matplotlib

# Use non-interactive backend for server environments (set before importing pyplot)
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from typing import List, Dict, Any, Optional, Tuple
from collections import Counter
import os
import logging

_logger = logging.getLogger("rag_visualization")


class RAGVisualizer:
    """Visualization utilities for RAG evaluation results"""

    # Method display names
    METHOD_NAMES = {
        'base_rag': 'Baseline',
        'tag_filter_rag': '+Tags(Filter)',
        'hybrid_rag': 'Hybrid(Weighted)',
        'hybrid_rerank_rag': 'Hybrid+Rerank'
    }

    def __init__(self):
        """Initialize visualizer"""
        pass

    def bar_chart_avg_performance(self, df: pd.DataFrame, metric: str = 'precision_at_k',
                                  k_value: Optional[int] = None) -> plt.Figure:
        """Create bar chart of average performance by method"""
        # Filter by k if specified
        if k_value:
            df = df[df['k'] == k_value]

        # Aggregate by pipeline
        avg_metrics = df.groupby('pipeline')[metric].mean().reset_index()

        # Rename pipelines
        avg_metrics['pipeline'] = avg_metrics['pipeline'].map(
            lambda x: self.METHOD_NAMES.get(x, x)
        )

        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(avg_metrics['pipeline'], avg_metrics[metric])
        ax.set_xlabel('Method')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(f'Average {metric.replace("_", " ").title()} by Method')
        ax.tick_params(axis='x', rotation=45)
        plt.tight_layout()

        return fig

    def box_plot_query_variance(self, df: pd.DataFrame, metric: str = 'precision_at_k',
                                k_value: Optional[int] = None) -> plt.Figure:
        """Create box plot showing query-level variance"""
        # Filter by k if specified
        if k_value:
            df = df[df['k'] == k_value]

        # Prepare data for box plot
        data = [df[df['pipeline'] == pipeline][metric].values
                for pipeline in df['pipeline'].unique()]
        labels = [self.METHOD_NAMES.get(p, p) for p in df['pipeline'].unique()]

        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.boxplot(data, labels=labels)
        ax.set_xlabel('Method')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(f'Query-Level Variance: {metric.replace("_", " ").title()}')
        ax.tick_params(axis='x', rotation=45)
        plt.tight_layout()

        return fig

    def scatter_plot_tags_vs_ndcg(self, df: pd.DataFrame, k_value: Optional[int] = None) -> plt.Figure:
        """Create scatter plot of tag count vs nDCG"""
        # Filter by k if specified
        if k_value:
            df = df[df['k'] == k_value]

        # Extracting tag counts from metadata would require additional processing
        # in evaluation; for now, create a simple scatter of nDCG per query index.
        fig, ax = plt.subplots(figsize=(10, 6))

        for pipeline in df['pipeline'].unique():
            pipeline_df = df[df['pipeline'] == pipeline]
            ax.scatter(range(len(pipeline_df)), pipeline_df['ndcg_at_k'],
                       label=self.METHOD_NAMES.get(pipeline, pipeline), alpha=0.6)

        ax.set_xlabel('Query Index')
        ax.set_ylabel('nDCG@k')
        ax.set_title('nDCG by Method and Query')
        ax.legend()
        plt.tight_layout()

        return fig

    def line_plot_metrics_over_k(self, df: pd.DataFrame, metric: str = 'precision_at_k') -> plt.Figure:
        """Create line plot showing metric trends over k values"""
        # Aggregate by pipeline and k
        metric_over_k = df.groupby(['pipeline', 'k'])[metric].mean().reset_index()

        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))

        for pipeline in df['pipeline'].unique():
            pipeline_df = metric_over_k[metric_over_k['pipeline'] == pipeline]
            ax.plot(pipeline_df['k'], pipeline_df[metric],
                    marker='o', label=self.METHOD_NAMES.get(pipeline, pipeline))

        ax.set_xlabel('k (Number of Results)')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(f'{metric.replace("_", " ").title()} Trends Over k Values')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        return fig

    def stacked_bar_user_ratings(self, df: pd.DataFrame) -> plt.Figure:
        """Create stacked bar chart of user satisfaction ratings"""
        if 'user_satisfaction' not in df.columns:
            # Create empty figure if no data
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.text(0.5, 0.5, 'No user satisfaction data available',
                    ha='center', va='center', transform=ax.transAxes)
            return fig

        # Group by pipeline and satisfaction score
        ratings = df.groupby(['pipeline', 'user_satisfaction']).size().unstack(fill_value=0)

        # Rename pipelines
        ratings.index = [self.METHOD_NAMES.get(p, p) for p in ratings.index]

        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))
        ratings.plot(kind='bar', stacked=True, ax=ax)
        ax.set_xlabel('Method')
        ax.set_ylabel('Count')
        ax.set_title('User Satisfaction Ratings Distribution')
        ax.legend(title='Rating (1-5)')
        ax.tick_params(axis='x', rotation=45)
        plt.tight_layout()

        return fig

    def pareto_chart_method_ranking(self, df: pd.DataFrame, metric: str = 'precision_at_k',
                                    k_value: Optional[int] = None) -> plt.Figure:
        """Create Pareto chart ranking methods by performance"""
        # Filter by k if specified
        if k_value:
            df = df[df['k'] == k_value]

        # Aggregate by pipeline
        avg_metrics = df.groupby('pipeline')[metric].mean().reset_index()
        avg_metrics = avg_metrics.sort_values(metric, ascending=True)

        # Rename pipelines
        avg_metrics['pipeline_display'] = avg_metrics['pipeline'].map(
            lambda x: self.METHOD_NAMES.get(x, x)
        )

        # Calculate cumulative percentage
        total = avg_metrics[metric].sum()
        avg_metrics['cumulative_pct'] = (avg_metrics[metric].cumsum() / total * 100)

        # Create figure with dual y-axes
        fig, ax1 = plt.subplots(figsize=(10, 6))

        # Bar chart
        ax1.barh(avg_metrics['pipeline_display'], avg_metrics[metric],
                 color='steelblue', alpha=0.7)
        ax1.set_xlabel(metric.replace('_', ' ').title())
        ax1.set_ylabel('Method')
        ax1.set_title(f'Method Performance Ranking (Pareto Chart) - {metric.replace("_", " ").title()}')

        # Cumulative line
        ax2 = ax1.twinx()
        ax2.plot(avg_metrics[metric].values, avg_metrics['cumulative_pct'].values,
                 'ro-', linewidth=2, markersize=8)
        ax2.set_ylabel('Cumulative Percentage (%)', color='red')
        ax2.tick_params(axis='y', labelcolor='red')
        ax2.set_ylim([0, 105])

        plt.tight_layout()
        return fig

    def save_figure(self, fig: plt.Figure, path: str):
        """Save figure to file"""
        fig.savefig(path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    def create_all_charts(self, df: pd.DataFrame, output_dir: str = "reports/visualizations",
                          k_value: Optional[int] = None) -> Dict[str, str]:
        """
        Create all available charts.

        Args:
            df: Evaluation results DataFrame
            output_dir: Output directory for charts
            k_value: Optional k value to filter

        Returns:
            Dict mapping chart name to file path (keys: 'bar', 'line', 'scatter', 'box', 'stacked_bar', 'pareto')
        """
        os.makedirs(output_dir, exist_ok=True)
        chart_paths = {}

        # Bar chart: Average performance
        try:
            fig = self.bar_chart_avg_performance(df, metric='precision_at_k', k_value=k_value)
            path = os.path.join(output_dir, 'bar_avg_performance.png')
            self.save_figure(fig, path)
            chart_paths['bar'] = path
        except Exception as e:
            _logger.warning(f"Error creating bar chart: {e}")

        # Line plot: Metrics over k values
        try:
            fig = self.line_plot_metrics_over_k(df, metric='precision_at_k')
            path = os.path.join(output_dir, 'line_metrics_over_k.png')
            self.save_figure(fig, path)
            chart_paths['line'] = path
        except Exception as e:
            _logger.warning(f"Error creating line plot: {e}")

        # Scatter plot: Tags vs nDCG
        try:
            fig = self.scatter_plot_tags_vs_ndcg(df, k_value=k_value)
            path = os.path.join(output_dir, 'scatter_tags_vs_ndcg.png')
            self.save_figure(fig, path)
            chart_paths['scatter'] = path
        except Exception as e:
            _logger.warning(f"Error creating scatter plot: {e}")

        # Box plot: Query variance
        try:
            fig = self.box_plot_query_variance(df, metric='precision_at_k', k_value=k_value)
            path = os.path.join(output_dir, 'box_query_variance.png')
            self.save_figure(fig, path)
            chart_paths['box'] = path
        except Exception as e:
            _logger.warning(f"Error creating box plot: {e}")

        # Stacked bar: User ratings (if available)
        if 'user_satisfaction' in df.columns and df['user_satisfaction'].notna().any():
            try:
                fig = self.stacked_bar_user_ratings(df)
                path = os.path.join(output_dir, 'stacked_bar_user_ratings.png')
                self.save_figure(fig, path)
                chart_paths['stacked_bar'] = path
            except Exception as e:
                _logger.warning(f"Error creating stacked bar chart: {e}")
        else:
            # Create placeholder chart if no user satisfaction data
            try:
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.text(0.5, 0.5, 'No user satisfaction data available',
                        ha='center', va='center', transform=ax.transAxes, fontsize=14)
                ax.set_title('User Satisfaction Ratings Distribution')
                path = os.path.join(output_dir, 'stacked_bar_user_ratings.png')
                self.save_figure(fig, path)
                chart_paths['stacked_bar'] = path
            except Exception as e:
                _logger.warning(f"Error creating placeholder stacked bar chart: {e}")

        # Pareto chart: Method ranking
        try:
            fig = self.pareto_chart_method_ranking(df, metric='precision_at_k', k_value=k_value)
            path = os.path.join(output_dir, 'pareto_method_ranking.png')
            self.save_figure(fig, path)
            chart_paths['pareto'] = path
        except Exception as e:
            _logger.warning(f"Error creating pareto chart: {e}")

        return chart_paths
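A quick smoke-test sketch for the chart helpers above, using a tiny synthetic frame with the column names the module expects (`pipeline`, `k`, `precision_at_k`, `ndcg_at_k`); in practice this DataFrame comes from the evaluator in core/eval.py:

```python
import pandas as pd
from core.visualization import RAGVisualizer

# Synthetic evaluation results; real frames come from RAGEvaluator.batch_evaluate.
df = pd.DataFrame({
    'pipeline': ['base_rag', 'hybrid_rag'] * 3,
    'k': [1, 1, 3, 3, 5, 5],
    'precision_at_k': [0.40, 0.60, 0.50, 0.70, 0.55, 0.72],
    'ndcg_at_k': [0.42, 0.61, 0.52, 0.71, 0.57, 0.74],
})

viz = RAGVisualizer()
paths = viz.create_all_charts(df, output_dir="reports/visualizations", k_value=3)
print(paths)  # keys: 'bar', 'line', 'scatter', 'box', 'stacked_bar', 'pareto'
```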
requirements.txt
ADDED
@@ -0,0 +1,55 @@
# requirements.txt
# Auto Tagging RAG System Dependencies

# Core Gradio dependencies (v5 UI components)
gradio==5.49.1
gradio-client==1.13.3

# LangChain dependencies
langchain>=0.1.0
langchain-community>=0.0.0

# Vector database
chromadb>=0.4.0

# PDF processing (pypdf is newer, PyPDF2 is legacy but still used)
pypdf>=3.0.0
PyPDF2>=3.0.0

# Embeddings and NLP
sentence-transformers>=2.2.0
tiktoken>=0.5.0

# Tag generation (English and Japanese)
yake>=0.4.0
keybert>=0.8.0
spacy>=3.7.0
janome>=0.5.0

# OpenAI (optional, for API usage)
openai>=1.0.0

# Testing
pytest>=7.0.0
pytest-asyncio>=0.21.0

# Utilities
python-dotenv>=1.0.0
PyYAML>=6.0  # Note: pip package name is PyYAML, not pyyaml

# Data processing
numpy>=1.21.0
pandas>=1.5.0
scikit-learn>=1.2.0

# Visualization and reports
matplotlib>=3.5.0
jinja2>=3.1.0

# MCP Server
mcp>=1.0.0

# FastAPI (for MCP/API endpoints)
fastapi>=0.110.0
starlette>=0.36.3
uvicorn>=0.23.0
tests/README.md
ADDED
@@ -0,0 +1,283 @@
# Test Suite - Auto Tagging RAG System

## Overview

This directory contains comprehensive pytest test cases for the Auto Tagging RAG System, including tests for the MCP server, accuracy, user experience, robustness, and non-technical user scenarios.

## Test Files

### `test_mcp_server.py`
**MCP Server Tests** - Tests for Model Context Protocol server functionality
- Tool listing (`list_tools`)
- Document search tool (`search_documents`) with all pipelines
- Evaluation tool (`evaluate_retrieval`)
- Error handling and edge cases
- Tag operators (OR/AND/NOT)
- Default parameter handling

### `test_accuracy.py`
**Accuracy Tests** - Tests for tag generation, retrieval, and evaluation metrics
- Tag generation accuracy (YAKE, KeyBERT, spaCy, Janome)
- Retrieval accuracy across all pipelines
- Evaluation metrics (Precision@k, nDCG@k, MRR)
- Metric range validation

### `test_ux.py`
**User Experience Tests** - Tests for user workflows and interface
- Document upload workflow
- Manual tag input
- Search workflows
- Evaluation workflows
- Session persistence
- Tag visualization
- Document count display

### `test_robustness.py`
**Robustness Tests** - Tests for error handling and edge cases
- Empty query handling
- Invalid k values
- Missing tags
- Invalid operators
- Empty documents
- Special characters
- Large k values
- Data integrity
- Performance tests

### `test_user_scenarios.py`
**Non-Technical User Scenarios** - Tests for users without technical knowledge
- First-time user document upload
- Simple search queries
- Evaluation with sample queries
- Custom tag input
- Session persistence
- Real-world workflows

### `test_japanese_support.py`
**Japanese Language Support** - Tests for Japanese language processing
- Japanese tag generation
- Language detection
- Japanese document processing
- Japanese search queries
- Mixed language handling

### `conftest.py`
**Pytest Fixtures** - Shared fixtures for all tests
- Temporary persistence directories
- RAGManager instances
- Evaluator instances
- Session managers
- Sample documents and queries
- MCP server instances
- Populated RAG managers

## Running Tests

### Install Dependencies

```bash
pip install -r requirements.txt
```

Note: `pytest>=7.0.0` and `pytest-asyncio>=0.21.0` are included in `requirements.txt`.

### Run All Tests

```bash
pytest tests/ -v
```

### Run Specific Test File

```bash
# MCP Server tests
pytest tests/test_mcp_server.py -v

# Accuracy tests
pytest tests/test_accuracy.py -v

# UX tests
pytest tests/test_ux.py -v

# Robustness tests
pytest tests/test_robustness.py -v

# User scenario tests
pytest tests/test_user_scenarios.py -v

# Japanese support tests
pytest tests/test_japanese_support.py -v
```

### Run Specific Test Class

```bash
pytest tests/test_mcp_server.py::TestMCPServer -v
```

### Run Specific Test Case

```bash
pytest tests/test_mcp_server.py::TestMCPServer::test_search_documents_base_rag -v
```

### Run with Coverage

```bash
pip install pytest-cov
pytest tests/ --cov=core --cov=app --cov-report=html
```

### Run Asynchronous Tests

MCP server tests use `pytest.mark.asyncio`. Ensure `pytest-asyncio` is installed:

```bash
pip install pytest-asyncio
pytest tests/test_mcp_server.py -v
```

## Test Structure

### Test Categories

1. **MCP Server Tests** (15+ tests)
   - Tool listing and discovery
   - All pipeline types (Base, Tag Filter, Hybrid, Hybrid Rerank)
   - Tag operators
   - Evaluation tool
   - Error handling

2. **Accuracy Tests** (10+ tests)
   - Tag generation for English and Japanese
   - Retrieval accuracy
   - Evaluation metrics

3. **User Experience Tests** (5+ tests)
   - Workflow tests
   - Interface tests
   - User interactions

4. **Robustness Tests** (10+ tests)
   - Error handling
   - Edge cases
   - Data integrity
   - Performance

5. **User Scenario Tests** (8+ tests)
   - Non-technical user workflows
   - Real-world scenarios

6. **Japanese Support Tests** (5+ tests)
   - Japanese language processing
   - Mixed language handling

**Total**: 50+ test cases

## Test Fixtures

### Available Fixtures

- `temp_persist_dir`: Temporary directory for ChromaDB (auto-cleanup)
- `rag_manager`: RAGManager instance with temporary persistence
- `evaluator`: RAGEvaluator instance
- `session_manager`: SessionManager instance
- `session_rag_manager`: SessionAwareRAGManager instance
- `sample_documents`: Sample document data (emergency, medical, surgery)
- `sample_queries`: Sample evaluation queries
- `mcp_server`: MCP server instance for testing
- `populated_rag_manager`: RAGManager with sample documents pre-indexed

## Writing New Tests

### Example Test Case

```python
def test_my_feature(populated_rag_manager):
    """Test my new feature"""
    query = "test query"
    result = populated_rag_manager.base_rag.retrieve(query, k=3)

    assert result is not None
    assert len(result.sources) > 0
```

### Using Fixtures

```python
def test_with_fixture(rag_manager, sample_documents):
    """Test using fixtures"""
    doc_data = sample_documents["emergency"]
    # Use doc_data for testing
```

### Async Tests (MCP Server)

```python
@pytest.mark.asyncio
async def test_mcp_tool(mcp_server):
    """Test MCP tool"""
    result = await mcp_server.call_tool("search_documents", {"query": "test"})
    assert result is not None
```

## Requirements

### Python Version
- Python 3.8+

### Dependencies
- `pytest>=7.0.0`
- `pytest-asyncio>=0.21.0` (for MCP server tests)
- `pytest-cov` (optional, for coverage)

### Models
- spaCy English model: `python -m spacy download en_core_web_sm`
- SentenceTransformers: downloads automatically
- MCP package: `pip install mcp` (for MCP server tests)

## Skipped Tests

Some tests may be skipped if optional dependencies are not installed:
- MCP server tests: skipped if the `mcp` package is not installed
- spaCy tests: skipped if the spaCy model is not available
- OpenAI tests: skipped if the OpenAI API key is not configured

## Test Maintenance

- Update tests when features change
- Add new tests for new features
- Review and refine tests regularly
- Keep test data updated

## Troubleshooting

### Import Errors
```bash
# Ensure you're in the project root
cd /path/to/auto_tagging_rag
pytest tests/ -v
```

### MCP Tests Fail
```bash
# Install MCP package
pip install mcp
pytest tests/test_mcp_server.py -v
```

### Model Not Found Errors
```bash
# Download required models
python -m spacy download en_core_web_sm
```

### Session Errors
- Tests use temporary directories (auto-cleanup)
- If tests fail, check temp directory permissions

---

**Last Updated**: 2024
**Test Suite Version**: 1.0
tests/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Tests package for Auto Tagging RAG System
tests/conftest.py
ADDED
@@ -0,0 +1,154 @@
"""
Pytest fixtures for Auto Tagging RAG System tests
"""
import pytest
import os
import tempfile
import shutil
from pathlib import Path

# Import core modules
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.retrieval import RAGManager
from core.eval import RAGEvaluator
from core.session_manager import SessionManager
from core.session_rag import SessionAwareRAGManager


@pytest.fixture(scope="function")
def temp_persist_dir():
    """Create a temporary directory for ChromaDB persistence"""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    # Cleanup
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, ignore_errors=True)


@pytest.fixture(scope="function")
def rag_manager(temp_persist_dir):
    """Create a RAGManager instance with temporary persistence"""
    return RAGManager(persist_directory=temp_persist_dir)


@pytest.fixture(scope="function")
def evaluator(rag_manager):
    """Create a RAGEvaluator instance"""
    return RAGEvaluator(rag_manager)


@pytest.fixture(scope="function")
def session_manager(temp_persist_dir):
    """Create a SessionManager instance with temporary persistence"""
    return SessionManager(base_persist_dir=temp_persist_dir)


@pytest.fixture(scope="function")
def session_rag_manager(rag_manager, session_manager):
    """Create a SessionAwareRAGManager instance"""
    return SessionAwareRAGManager(rag_manager, session_manager)


@pytest.fixture(scope="function")
def sample_documents():
    """Sample documents for testing"""
    return {
        "emergency": {
            "content": """
            Emergency Procedures
            In case of fire, immediately activate the nearest fire alarm and evacuate the building.
            Fire safety protocols require all personnel to know the location of fire extinguishers.
            Do not use elevators during a fire emergency. Stay low to avoid smoke inhalation.
            """,
            "tags": ["fire", "emergency", "safety", "evacuation"],
            "language": "en"
        },
        "medical": {
            "content": """
            Medical Emergency Response
            Medical emergency response begins with assessing patient ABC (Airway, Breathing, Circulation).
            Emergency protocols require immediate notification of medical team.
            If the patient is unresponsive, call for emergency medical services immediately.
            """,
            "tags": ["medical", "emergency", "patient", "response"],
            "language": "en"
        },
        "surgery": {
            "content": """
            Surgical Safety Protocols
            All surgical procedures require pre-operative checklists and sterile environment protocols.
            Surgical safety includes patient identification verification and site marking procedures.
            The surgical site should be marked with an indelible marker before the patient enters the operating room.
            """,
            "tags": ["surgery", "safety", "protocol", "patient"],
            "language": "en"
        }
    }


@pytest.fixture(scope="function")
def sample_queries():
    """Sample evaluation queries for testing"""
    return [
        {
            "query": "What are the emergency procedures for fire incidents?",
            "ground_truth": [
                "In case of fire, immediately activate the nearest fire alarm and evacuate the building.",
                "Fire safety protocols require all personnel to know the location of fire extinguishers."
            ],
            "k_values": [1, 3, 5],
            "tags": ["fire", "emergency", "safety"],
            "tag_operator": "OR"
        },
        {
            "query": "How to handle medical emergencies?",
            "ground_truth": [
                "Medical emergency response begins with assessing patient ABC (Airway, Breathing, Circulation).",
                "Emergency protocols require immediate notification of medical team."
            ],
            "k_values": [1, 3, 5],
            "tags": ["medical", "emergency", "patient"],
            "tag_operator": "OR"
        }
    ]


@pytest.fixture(scope="function")
def mcp_server(rag_manager):
    """Create an MCP server instance for testing"""
    try:
        from app import RAGMCPServer
        # Override the persist_dir set in __init__ by modifying the instance
        server = RAGMCPServer()
        server.rag_manager = rag_manager
        server.evaluator = RAGEvaluator(rag_manager)
        return server
    except ImportError:
        pytest.skip("MCP server not available (mcp package not installed)")


@pytest.fixture(scope="function")
def populated_rag_manager(rag_manager, sample_documents):
    """Create a RAGManager with sample documents already indexed"""
    from core.ingest import FlatTagChunker
    from core.utils import Chunk

    # Create sample chunks
    all_chunks = []
    for doc_name, doc_data in sample_documents.items():
        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(
            doc_data["content"],
            language=doc_data["language"],
            user_tags=None
        )
        all_chunks.extend(chunks)

    # Index chunks
    if all_chunks:
        rag_manager.vector_store.add_documents("documents", all_chunks)

    return rag_manager
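The fixtures above form a dependency chain: `temp_persist_dir` feeds `rag_manager`, which feeds `evaluator` and `populated_rag_manager`, all at function scope so every test starts from a clean ChromaDB directory. A minimal sketch of a test that leans on this chain (the `get_collection_stats` call matches its use in tests/test_japanese_support.py):

```python
def test_fixture_chain(populated_rag_manager):
    """Each test receives a fresh temp directory with the three
    sample documents (emergency, medical, surgery) already indexed."""
    stats = populated_rag_manager.vector_store.get_collection_stats("documents")
    assert stats["chunk_count"] > 0
```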
tests/test_accuracy.py
ADDED
@@ -0,0 +1,231 @@
"""
Accuracy tests for Auto Tagging RAG System
Tests tag generation, retrieval accuracy, and evaluation metrics
"""
import pytest
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.tag_generator import TagGenerator
from core.ingest import FlatTagChunker
from core.utils import Chunk


class TestTagGenerationAccuracy:
    """Test tag generation accuracy"""

    def test_english_tag_generation_yake(self):
        """Test YAKE tag generation for English documents"""
        text = """
        Emergency Procedures
        In case of fire, immediately activate the nearest fire alarm and evacuate the building.
        Fire safety protocols require all personnel to know the location of fire extinguishers.
        """

        generator = TagGenerator()
        tags = generator.generate_tags(text, method="yake", language="en", max_tags=10)

        assert len(tags) > 0
        assert isinstance(tags, list)
        # Check for relevant tags (case-insensitive)
        tag_str = " ".join(tags).lower()
        assert any(keyword in tag_str for keyword in ["fire", "emergency", "safety"])

    def test_english_tag_generation_keybert(self):
        """Test KeyBERT tag generation for English documents"""
        text = """
        Medical Emergency Response
        Medical emergency response begins with assessing patient ABC.
        Emergency protocols require immediate notification of medical team.
        """

        generator = TagGenerator()
        tags = generator.generate_tags(text, method="keybert", language="en", max_tags=10)

        assert len(tags) > 0
        assert isinstance(tags, list)
        tag_str = " ".join(tags).lower()
        assert any(keyword in tag_str for keyword in ["medical", "emergency", "patient"])

    def test_english_tag_generation_spacy(self):
        """Test spaCy tag generation for English documents"""
        text = """
        Surgical Safety Protocols
        All surgical procedures require pre-operative checklists.
        Surgical safety includes patient identification verification.
        """

        generator = TagGenerator()
        try:
            tags = generator.generate_tags(text, method="spacy", language="en", max_tags=10)
            assert len(tags) > 0
            assert isinstance(tags, list)
        except Exception as e:
            pytest.skip(f"spaCy model not available: {e}")

    def test_japanese_tag_generation(self):
        """Test Japanese tag generation"""
        text = """
        緊急時の手順
        火災の場合は、最寄りの火災報知器をすぐに作動させ、建物から避難してください。
        防火安全プロトコルでは、すべての職員が消火器の場所を知っている必要があります。
        """

        generator = TagGenerator()
        tags = generator.generate_tags(text, method="janome", language="ja", max_tags=10)

        assert len(tags) > 0
        assert isinstance(tags, list)

    def test_auto_tag_method_selection(self):
        """Test automatic method selection"""
        generator = TagGenerator()

        text = "Emergency procedures for fire safety."
        tags = generator.generate_tags(text, method="auto", language="en", max_tags=5)

        assert len(tags) > 0
        assert isinstance(tags, list)


class TestRetrievalAccuracy:
    """Test retrieval accuracy"""

    def test_base_rag_retrieval(self, rag_manager, populated_rag_manager):
        """Test Base RAG retrieval returns relevant documents"""
        query = "What are emergency procedures for fire incidents?"
        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        assert result is not None
        assert len(result.sources) > 0
        assert result.latency > 0

        # Check if results contain relevant content
        content_lower = " ".join([s['content'].lower() for s in result.sources])
        assert "fire" in content_lower or "emergency" in content_lower

    def test_tag_filter_rag_or_operator(self, populated_rag_manager):
        """Test Tag Filter RAG with OR operator"""
        query = "What are emergency procedures?"
        tags = ["fire", "emergency"]

        result = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=tags, tag_operator="OR"
        )

        assert result is not None
        # Should return documents with at least one matching tag
        if len(result.sources) > 0:
            for source in result.sources:
                source_tags = source.get('metadata', {}).get('tags', [])
                if isinstance(source_tags, str):
                    source_tags = [t.strip() for t in source_tags.split(',')]
                assert any(tag.lower() in str(source_tags).lower() for tag in tags)

    def test_tag_filter_rag_and_operator(self, populated_rag_manager):
        """Test Tag Filter RAG with AND operator"""
        query = "What are emergency procedures?"
        tags = ["emergency", "safety"]

        result = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=tags, tag_operator="AND"
        )

        assert result is not None
        # Should return documents with all matching tags (or empty if none match all)

    def test_hybrid_rag_retrieval(self, populated_rag_manager):
        """Test Hybrid RAG retrieval"""
        query = "How to handle medical emergencies?"
        tags = ["medical", "emergency"]

        result = populated_rag_manager.hybrid_rag.retrieve(
            query, k=3, tags=tags, vector_weight=0.7, tag_weight=0.3
        )

        assert result is not None
        assert len(result.sources) > 0
        assert result.latency > 0

    def test_hybrid_rerank_rag_retrieval(self, populated_rag_manager):
        """Test Hybrid Rerank RAG retrieval"""
        query = "What are surgical safety protocols?"
        tags = ["surgery", "safety"]

        result = populated_rag_manager.hybrid_rerank_rag.retrieve(
            query, k=3, tags=tags, vector_weight=0.7, tag_weight=0.3
        )

        assert result is not None
        assert len(result.sources) > 0
        assert result.latency > 0

        # Reranked results should have rerank_score or hybrid_score
        if len(result.sources) > 0:
            first_source = result.sources[0]
            assert 'score' in first_source


class TestEvaluationMetrics:
    """Test evaluation metrics accuracy"""

    def test_precision_at_k_calculation(self, evaluator, populated_rag_manager, sample_queries):
        """Test Precision@k calculation"""
        query_data = sample_queries[0]
        query = query_data["query"]
        ground_truth = query_data["ground_truth"]

        # Retrieve results
        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        # Calculate metrics manually
        retrieved = [s['content'] for s in result.sources]

        # Calculate precision@k
        relevant_retrieved = 0
        for gt in ground_truth:
            for ret in retrieved:
                if gt.lower() in ret.lower() or ret.lower() in gt.lower():
                    relevant_retrieved += 1
                    break

        precision = relevant_retrieved / len(retrieved) if retrieved else 0

        assert 0 <= precision <= 1

    def test_evaluation_batch_evaluate(self, evaluator, populated_rag_manager, sample_queries):
        """Test batch evaluation produces correct metrics"""
        df, summary, results = evaluator.batch_evaluate(
            sample_queries,
            output_file=None,
            pipelines=['base_rag', 'tag_filter_rag', 'hybrid_rag', 'hybrid_rerank_rag']
        )

        assert df is not None
        assert len(df) > 0

        # Check required columns
        required_columns = ['query', 'k', 'pipeline', 'precision_at_k', 'ndcg_at_k', 'mrr']
        for col in required_columns:
            assert col in df.columns

        # Check summary structure
        assert isinstance(summary, dict)
        assert 'summary_stats' in summary or isinstance(summary, dict)

    def test_metrics_ranges(self, evaluator, populated_rag_manager, sample_queries):
        """Test that all metrics are in valid ranges"""
        df, summary, results = evaluator.batch_evaluate(
            sample_queries[:1],  # Use one query for speed
            output_file=None,
            pipelines=['base_rag']
        )

        # Check metric ranges
        assert all(0 <= val <= 1 for val in df['precision_at_k'].dropna())
        assert all(0 <= val <= 1 for val in df['ndcg_at_k'].dropna())
        assert all(0 <= val <= 1 for val in df['mrr'].dropna())
        assert all(0 <= val <= 1 for val in df['hit_at_k'].dropna())
tests/test_japanese_support.py
ADDED
@@ -0,0 +1,87 @@
"""
Japanese language support tests for Auto Tagging RAG System
"""
import pytest
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.tag_generator import TagGenerator
from core.ingest import FlatTagChunker


class TestJapaneseLanguageSupport:
    """Test Japanese language processing"""

    def test_japanese_tag_generation(self):
        """Test Japanese tag generation using Janome"""
        text = """
        緊急時の手順
        火災の場合は、最寄りの火災報知器をすぐに作動させ、建物から避難してください。
        防火安全プロトコルでは、すべての職員が消火器の場所を知っている必要があります。
        """

        generator = TagGenerator()
        tags = generator.generate_tags(text, method="janome", language="ja", max_tags=10)

        assert len(tags) > 0
        assert isinstance(tags, list)

    def test_japanese_language_detection(self):
        """Test Japanese language auto-detection"""
        text = "火災の場合は、すぐに避難してください。"

        generator = TagGenerator()
        # Should detect Japanese and use the appropriate method
        tags = generator.generate_tags(text, method="auto", language=None, max_tags=5)

        assert len(tags) > 0

    def test_japanese_document_processing(self, rag_manager):
        """Test processing Japanese documents"""
        text = """
        医療緊急対応
        医療緊急対応は、患者のABC(気道、呼吸、循環)を評価することから始まります。
        緊急プロトコルでは、医療チームへの即座の通知が必要です。
        """

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language="ja", user_tags=None)

        assert len(chunks) > 0

        # Index Japanese documents
        rag_manager.vector_store.add_documents("documents", chunks)
        stats = rag_manager.vector_store.get_collection_stats("documents")

        assert stats["chunk_count"] > 0

    def test_japanese_search(self, rag_manager):
        """Test searching Japanese documents"""
        # Index Japanese document
        text = "火災安全プロトコル"
        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language="ja", user_tags=None)

        if chunks:
            rag_manager.vector_store.add_documents("documents", chunks)

        # Search with Japanese query
        query = "火災"
        result = rag_manager.base_rag.retrieve(query, k=3)

        assert result is not None

    def test_mixed_language_documents(self, rag_manager):
        """Test handling of mixed language documents"""
        text = "Emergency 緊急 Fire 火災 Safety 安全"

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language=None, user_tags=None)

        # Should handle mixed content gracefully
        assert chunks is not None
        assert isinstance(chunks, list)
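`test_japanese_language_detection` above exercises `method="auto"` with `language=None`; the detector itself lives in core/tag_generator.py. For intuition, a common Unicode-range heuristic looks like the sketch below (an illustration, not the project's implementation):

```python
def looks_japanese(text: str) -> bool:
    """Heuristic: any hiragana, katakana, or CJK ideograph implies Japanese."""
    return any(
        0x3040 <= ord(ch) <= 0x30FF or 0x4E00 <= ord(ch) <= 0x9FFF
        for ch in text
    )

assert looks_japanese("火災の場合は、すぐに避難してください。")
assert not looks_japanese("Evacuate immediately in case of fire.")
```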
tests/test_mcp_server.py
ADDED
@@ -0,0 +1,263 @@
"""
Test cases for MCP Server functionality
"""
import pytest
import json
import asyncio
from pathlib import Path

# Import core modules
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

# Check if MCP is available
try:
    import mcp
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False


@pytest.mark.asyncio
@pytest.mark.skipif(not MCP_AVAILABLE, reason="MCP package not installed")
class TestMCPServer:
    """Test cases for MCP Server"""

    async def test_list_tools(self, mcp_server):
        """Test that MCP server lists available tools"""
        tools = await mcp_server.list_tools()

        assert len(tools) >= 2
        tool_names = [tool.name for tool in tools]
        assert "search_documents" in tool_names
        assert "evaluate_retrieval" in tool_names

    async def test_search_documents_base_rag(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with Base RAG pipeline"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "What are emergency procedures for fire?",
            "k": 3,
            "pipeline": "base_rag"
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "content" in response
        assert "sources" in response
        assert "latency" in response
        assert response["strategy"] == "base_rag"
        assert len(response["sources"]) <= 3

    async def test_search_documents_tag_filter_rag(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with Tag Filter RAG pipeline"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "What are emergency procedures?",
            "k": 3,
            "pipeline": "tag_filter_rag",
            "tags": ["fire", "emergency"],
            "tag_operator": "OR"
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "content" in response
        assert "sources" in response
        assert response["strategy"] == "tag_filter_rag"

    async def test_search_documents_hybrid_rag(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with Hybrid RAG pipeline"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "How to handle medical emergencies?",
            "k": 3,
            "pipeline": "hybrid_rag",
            "tags": ["medical", "emergency"]
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "content" in response
        assert "sources" in response
        assert response["strategy"] == "hybrid_rag"

    async def test_search_documents_hybrid_rerank_rag(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with Hybrid Rerank RAG pipeline"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "What are surgical safety protocols?",
            "k": 3,
            "pipeline": "hybrid_rerank_rag",
            "tags": ["surgery", "safety"]
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "content" in response
        assert "sources" in response
        assert response["strategy"] == "hybrid_rerank_rag"

    async def test_search_documents_default_parameters(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with default parameters"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "What are emergency procedures?"
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "content" in response
        assert "sources" in response
        assert len(response["sources"]) <= 5  # Default k=5

    async def test_search_documents_invalid_pipeline(self, mcp_server, populated_rag_manager):
        """Test search_documents tool with invalid pipeline falls back to base_rag"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "What are emergency procedures?",
            "pipeline": "invalid_pipeline"
        }

        result = await mcp_server.call_tool("search_documents", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        # Should fall back to base_rag
        assert "content" in response
        assert "sources" in response

    async def test_evaluate_retrieval(self, mcp_server, populated_rag_manager, sample_queries):
        """Test evaluate_retrieval tool"""
        from core.eval import RAGEvaluator

        mcp_server.rag_manager = populated_rag_manager
        mcp_server.evaluator = RAGEvaluator(populated_rag_manager)

        arguments = {
            "queries": sample_queries,
            "output_file": "test_evaluation.json"
        }

        result = await mcp_server.call_tool("evaluate_retrieval", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "summary" in response
        assert "total_queries" in response
        assert response["total_queries"] == len(sample_queries)
        assert len(response["summary"]) > 0

    async def test_evaluate_retrieval_no_output_file(self, mcp_server, populated_rag_manager, sample_queries):
        """Test evaluate_retrieval tool without output file"""
        from core.eval import RAGEvaluator

        mcp_server.rag_manager = populated_rag_manager
        mcp_server.evaluator = RAGEvaluator(populated_rag_manager)

        arguments = {
            "queries": sample_queries
        }

        result = await mcp_server.call_tool("evaluate_retrieval", arguments)

        assert len(result) > 0
        response = json.loads(result[0].text)

        assert "summary" in response
        assert "total_queries" in response

    async def test_call_tool_invalid_tool(self, mcp_server):
        """Test calling invalid tool raises error"""
        with pytest.raises(ValueError, match="Unknown tool"):
            await mcp_server.call_tool("invalid_tool", {})

    async def test_search_documents_empty_query(self, mcp_server, populated_rag_manager):
        """Test search_documents with empty query"""
        mcp_server.rag_manager = populated_rag_manager

        arguments = {
            "query": "",
            "k": 3
        }

        # Should not raise error, but may return empty results
        result = await mcp_server.call_tool("search_documents", arguments)
        assert len(result) > 0

    async def test_search_documents_tag_operators(self, mcp_server, populated_rag_manager):
        """Test search_documents with different tag operators"""
        mcp_server.rag_manager = populated_rag_manager

        operators = ["OR", "AND", "NOT"]
        for operator in operators:
            arguments = {
                "query": "What are emergency procedures?",
                "k": 3,
                "pipeline": "tag_filter_rag",
                "tags": ["fire", "emergency"],
                "tag_operator": operator
            }

            result = await mcp_server.call_tool("search_documents", arguments)
            assert len(result) > 0
            response = json.loads(result[0].text)
            assert "sources" in response

    async def test_mcp_server_initialization(self, temp_persist_dir):
        """Test MCP server initialization"""
        try:
            from app import RAGMCPServer
            server = RAGMCPServer()

            assert server.rag_manager is not None
            assert server.evaluator is not None

            # Test list_tools
            tools = await server.list_tools()
            assert len(tools) >= 2
        except ImportError:
            pytest.skip("MCP server not available")

    async def test_search_documents_all_pipelines(self, mcp_server, populated_rag_manager):
        """Test search_documents with all pipeline types"""
        mcp_server.rag_manager = populated_rag_manager

        pipelines = ["base_rag", "tag_filter_rag", "hybrid_rag", "hybrid_rerank_rag"]

        for pipeline in pipelines:
            arguments = {
                "query": "What are emergency procedures?",
                "k": 3,
                "pipeline": pipeline,
                "tags": ["emergency"]
            }

            result = await mcp_server.call_tool("search_documents", arguments)
            assert len(result) > 0
            response = json.loads(result[0].text)
            assert response["strategy"] == pipeline
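These async tests depend on pytest-asyncio and on shared fixtures (mcp_server, populated_rag_manager, sample_queries, temp_persist_dir) defined in tests/conftest.py. As a rough sketch only — the actual definitions live in tests/conftest.py and are not shown in this hunk — the mcp_server fixture presumably amounts to something like:

# Hypothetical sketch of the conftest.py fixture these tests assume.
import pytest

@pytest.fixture
def mcp_server():
    # Mirrors test_mcp_server_initialization: the server class is exported by app.py.
    from app import RAGMCPServer
    return RAGMCPServer()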
tests/test_robustness.py
ADDED
@@ -0,0 +1,194 @@
"""
Robustness tests for Auto Tagging RAG System
Tests error handling, edge cases, and data integrity
"""
import pytest
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))


class TestErrorHandling:
    """Test error handling and recovery"""

    def test_empty_query_handling(self, populated_rag_manager):
        """Test handling of empty queries"""
        query = ""

        # Should not crash, may return empty results
        result = populated_rag_manager.base_rag.retrieve(query, k=3)
        assert result is not None

    def test_invalid_k_value(self, populated_rag_manager):
        """Test handling of invalid k values"""
        query = "test query"

        # Test with k=0
        result = populated_rag_manager.base_rag.retrieve(query, k=0)
        assert result is not None

        # Test with negative k
        result = populated_rag_manager.base_rag.retrieve(query, k=-1)
        assert result is not None

    def test_missing_tags_in_tag_filter(self, populated_rag_manager):
        """Test tag filter with tags that don't exist"""
        query = "test query"

        result = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=["nonexistent-tag-xyz"], tag_operator="OR"
        )

        assert result is not None
        # May return empty results, but shouldn't crash

    def test_invalid_tag_operator(self, populated_rag_manager):
        """Test handling of invalid tag operator"""
        query = "test query"

        # Should default to OR or handle gracefully
        result = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=["emergency"], tag_operator="INVALID"
        )

        assert result is not None

    def test_empty_document_handling(self, rag_manager):
        """Test handling of empty documents"""
        from core.ingest import FlatTagChunker

        chunker = FlatTagChunker()

        # Empty document
        chunks = chunker.chunk_document("", language="en", user_tags=None)

        # Should handle gracefully (may return empty list)
        assert chunks is not None
        assert isinstance(chunks, list)


class TestEdgeCases:
    """Test edge cases and boundary conditions"""

    def test_very_short_document(self, rag_manager):
        """Test processing of very short documents"""
        from core.ingest import FlatTagChunker

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document("Emergency!", language="en", user_tags=None)

        assert chunks is not None
        assert isinstance(chunks, list)

    def test_special_characters_in_document(self, rag_manager):
        """Test handling of special characters"""
        from core.ingest import FlatTagChunker

        text = "Emergency! 🚨 Fire safety (protocol #1) requires: 1) Alert 2) Evacuate"
        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language="en", user_tags=None)

        assert chunks is not None
        assert len(chunks) > 0

    def test_large_k_value(self, populated_rag_manager):
        """Test retrieval with large k value"""
        query = "test query"
        result = populated_rag_manager.base_rag.retrieve(query, k=100)

        assert result is not None
        # Should not crash, may return fewer results than requested

    def test_many_tags(self, populated_rag_manager):
        """Test tag filtering with many tags"""
        query = "test query"
        many_tags = [f"tag_{i}" for i in range(50)]

        result = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=many_tags, tag_operator="OR"
        )

        assert result is not None


class TestDataIntegrity:
    """Test data integrity and consistency"""

    def test_document_count_accuracy(self, rag_manager, sample_documents):
        """Test document count reflects unique documents"""
        from core.ingest import FlatTagChunker

        # Index multiple chunks from same document
        all_chunks = []
        doc_data = sample_documents["emergency"]

        chunker = FlatTagChunker()
        # Create chunks multiple times to simulate chunking
        for _ in range(2):
            chunks = chunker.chunk_document(
                doc_data["content"],
                language=doc_data["language"],
                user_tags=None
            )
            all_chunks.extend(chunks)

        if all_chunks:
            rag_manager.vector_store.add_documents("documents", all_chunks)
            stats = rag_manager.vector_store.get_collection_stats("documents")

            # Should count unique documents, not chunks
            assert stats["document_count"] > 0

    def test_tag_consistency(self, rag_manager):
        """Test tag generation consistency"""
        from core.tag_generator import TagGenerator

        text = "Emergency procedures for fire safety and evacuation protocols."
        generator = TagGenerator()

        tags1 = generator.generate_tags(text, method="yake", language="en", max_tags=5)
        tags2 = generator.generate_tags(text, method="yake", language="en", max_tags=5)

        # Tags should be similar (may vary slightly due to randomness)
        assert len(tags1) > 0
        assert len(tags2) > 0

    def test_session_isolation(self, session_manager):
        """Test session isolation"""
        # Create two sessions
        session1 = session_manager.create_session(user_id="user1")
        session2 = session_manager.create_session(user_id="user2")

        assert session1.session_id != session2.session_id
        assert session1.collection_name != session2.collection_name


class TestPerformance:
    """Test performance and resource usage"""

    def test_retrieval_latency(self, populated_rag_manager):
        """Test retrieval latency is reasonable"""
        query = "What are emergency procedures?"

        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        # Should complete within reasonable time (10 seconds for test)
        assert result.latency < 10.0
        assert result.latency > 0

    def test_evaluation_performance(self, evaluator, populated_rag_manager, sample_queries):
        """Test evaluation completes in reasonable time"""
        import time

        start = time.time()
        df, summary, results = evaluator.batch_evaluate(
            sample_queries[:2],  # Use 2 queries
            output_file=None,
            pipelines=['base_rag']
        )
        elapsed = time.time() - start

        # Should complete within 60 seconds for 2 queries
        assert elapsed < 60.0
        assert df is not None
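The assertions above treat each retrieval result as an object exposing content, sources, and latency attributes. The real class lives in the core package (likely core/retrieval.py, which this commit also adds); the following is only a shape inferred from usage, not the actual definition:

# Assumed shape of a retrieval result, reconstructed from the assertions above.
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class RetrievalResult:
    content: str                   # answer or concatenated retrieved text
    sources: List[Dict[str, Any]]  # retrieved chunks with their metadata
    latency: float                 # wall-clock seconds spent on retrieval
    strategy: str = "base_rag"     # pipeline name, e.g. "hybrid_rerank_rag"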
tests/test_user_scenarios.py
ADDED
@@ -0,0 +1,174 @@
"""
Non-technical user scenario tests for Auto Tagging RAG System
Tests workflows that don't require technical knowledge
"""
import pytest
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))


class TestNonTechnicalUserScenarios:
    """Test scenarios for non-technical users"""

    def test_first_time_user_upload(self, rag_manager):
        """Test first-time user can upload documents"""
        from core.ingest import FlatTagChunker

        # Simulate user action: upload document with default settings
        text = """
        Emergency Procedures
        In case of fire, activate fire alarm and evacuate.
        Know the location of fire extinguishers.
        """

        # User doesn't specify language (Auto-detect)
        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language=None, user_tags=None)

        assert len(chunks) > 0

        # User builds index
        rag_manager.vector_store.add_documents("documents", chunks)

        # Verify it worked
        stats = rag_manager.vector_store.get_collection_stats("documents")
        assert stats["chunk_count"] > 0

    def test_simple_search_query(self, populated_rag_manager):
        """Test non-technical user can search without understanding technical details"""
        # User enters simple query
        query = "What are emergency procedures?"

        # Uses default pipeline (Base RAG)
        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        # User sees results
        assert result is not None
        assert len(result.sources) > 0
        assert result.content  # Results are readable

    def test_search_without_tags(self, populated_rag_manager):
        """Test user can search without understanding tags"""
        query = "How to handle medical emergencies?"

        # User doesn't provide tags, uses default pipeline
        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        assert result is not None
        assert len(result.sources) > 0

    def test_evaluation_with_sample_queries(self, evaluator, populated_rag_manager):
        """Test user can run evaluation with sample queries"""
        # User copies sample queries (simplified)
        queries = [
            {
                "query": "What are emergency procedures for fire?",
                "ground_truth": ["Fire safety protocols"],
                "k_values": [1, 3, 5]
            }
        ]

        # User runs evaluation
        df, summary, results = evaluator.batch_evaluate(
            queries,
            output_file=None,
            pipelines=['base_rag']
        )

        # User sees results
        assert df is not None
        assert len(df) > 0

    def test_user_adds_custom_tags(self, rag_manager):
        """Test user can add custom tags without understanding auto-tagging"""
        from core.ingest import FlatTagChunker

        text = "Emergency procedures document."

        # User adds custom tags
        user_tags = ["important", "must-read"]

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(text, language="en", user_tags=user_tags)

        # Verify user tags are included
        if chunks:
            chunk_tags = chunks[0].metadata.get('tags', [])
            if isinstance(chunk_tags, str):
                chunk_tags = [t.strip() for t in chunk_tags.split(',')]

            chunk_tags_lower = [t.lower() for t in chunk_tags]
            assert any(ut.lower() in chunk_tags_lower for ut in user_tags)

    def test_session_persistence_for_user(self, session_manager):
        """Test session persists across browser refresh for non-technical user"""
        # User creates session (automatic)
        session = session_manager.create_session(user_id="user_123")
        session_id = session.session_id

        # User refreshes browser - session should persist
        retrieved = session_manager.get_session(session_id)

        assert retrieved is not None
        assert retrieved.session_id == session_id


class TestRealWorldScenarios:
    """Test real-world usage scenarios"""

    def test_document_processing_workflow(self, rag_manager, sample_documents):
        """Test complete document processing workflow"""
        from core.ingest import FlatTagChunker

        # User uploads multiple documents
        all_chunks = []
        for doc_name, doc_data in sample_documents.items():
            chunker = FlatTagChunker()
            chunks = chunker.chunk_document(
                doc_data["content"],
                language=doc_data["language"],
                user_tags=None
            )
            all_chunks.extend(chunks)

        # User builds index
        if all_chunks:
            rag_manager.vector_store.add_documents("documents", all_chunks)
            stats = rag_manager.vector_store.get_collection_stats("documents")

            assert stats["chunk_count"] > 0
            assert stats["document_count"] > 0

    def test_search_comparison_workflow(self, populated_rag_manager):
        """Test comparing different search methods"""
        query = "What are emergency procedures?"

        # User compares methods
        results = {}

        results['base'] = populated_rag_manager.base_rag.retrieve(query, k=3)
        results['tag'] = populated_rag_manager.tag_filter_rag.retrieve(
            query, k=3, tags=["emergency"], tag_operator="OR"
        )
        results['hybrid'] = populated_rag_manager.hybrid_rag.retrieve(
            query, k=3, tags=["emergency"], vector_weight=0.7, tag_weight=0.3
        )

        # All should return results
        for method, result in results.items():
            assert result is not None
            assert result.latency > 0

    def test_chat_interface_workflow(self, populated_rag_manager):
        """Test natural conversation workflow"""
        query = "Tell me about fire safety"

        # User asks question
        result = populated_rag_manager.base_rag.retrieve(query, k=3)

        # User sees answer and sources
        assert result is not None
        assert result.content
        assert len(result.sources) > 0
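All of these scenarios reduce to the same chunk → index → retrieve loop. A condensed sketch of that workflow, composed only of calls that appear in the tests above (the helper function itself is illustrative, not part of the commit):

# Minimal end-to-end sketch built from the calls exercised by these scenario tests.
from core.ingest import FlatTagChunker

def index_and_search(rag_manager, text, query, k=3):
    # 1. Chunk with auto-detected language and auto-generated tags
    chunks = FlatTagChunker().chunk_document(text, language=None, user_tags=None)
    # 2. Index into the shared "documents" collection
    if chunks:
        rag_manager.vector_store.add_documents("documents", chunks)
    # 3. Retrieve with the default (Base RAG) pipeline
    return rag_manager.base_rag.retrieve(query, k=k)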
tests/test_ux.py
ADDED
@@ -0,0 +1,156 @@
"""
User Experience tests for Auto Tagging RAG System
Tests UI workflows, user interactions, and usability
"""
import pytest
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from core.ingest import FlatTagChunker
from core.utils import Chunk


class TestUserWorkflows:
    """Test user workflows and interactions"""

    def test_document_upload_workflow(self, rag_manager, sample_documents):
        """Test complete document upload workflow"""
        doc_data = sample_documents["emergency"]

        # Simulate upload workflow
        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(
            doc_data["content"],
            language=doc_data["language"],
            user_tags=None
        )

        assert len(chunks) > 0

        # Index chunks
        rag_manager.vector_store.add_documents("documents", chunks)

        # Verify indexing
        stats = rag_manager.vector_store.get_collection_stats("documents")
        assert stats["chunk_count"] >= len(chunks)

    def test_manual_tag_input(self, rag_manager, sample_documents):
        """Test manual tag input during upload"""
        doc_data = sample_documents["emergency"]
        user_tags = ["custom-tag-1", "custom-tag-2"]

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(
            doc_data["content"],
            language=doc_data["language"],
            user_tags=user_tags
        )

        # Check that user tags are included
        if chunks:
            chunk_tags = chunks[0].metadata.get('tags', [])
            if isinstance(chunk_tags, str):
                chunk_tags = [t.strip() for t in chunk_tags.split(',')]

            # User tags should be present (may be lowercased)
            chunk_tags_lower = [t.lower() for t in chunk_tags]
            assert any(ut.lower() in chunk_tags_lower for ut in user_tags)

    def test_search_workflow(self, populated_rag_manager):
        """Test search workflow returns results"""
        query = "What are emergency procedures?"

        # Test all pipelines
        pipelines = {
            "base_rag": populated_rag_manager.base_rag,
            "tag_filter_rag": populated_rag_manager.tag_filter_rag,
            "hybrid_rag": populated_rag_manager.hybrid_rag,
            "hybrid_rerank_rag": populated_rag_manager.hybrid_rerank_rag
        }

        for pipeline_name, pipeline in pipelines.items():
            if pipeline_name == "tag_filter_rag":
                result = pipeline.retrieve(query, k=3, tags=["emergency"], tag_operator="OR")
            elif pipeline_name in ["hybrid_rag", "hybrid_rerank_rag"]:
                result = pipeline.retrieve(query, k=3, tags=["emergency"], vector_weight=0.7, tag_weight=0.3)
            else:
                result = pipeline.retrieve(query, k=3)

            assert result is not None
            assert result.latency > 0

    def test_evaluation_workflow(self, evaluator, populated_rag_manager, sample_queries):
        """Test evaluation workflow produces results"""
        df, summary, results = evaluator.batch_evaluate(
            sample_queries[:2],  # Use 2 queries for speed
            output_file=None,
            pipelines=['base_rag']
        )

        assert df is not None
        assert len(df) > 0
        assert summary is not None
        assert results is not None

    def test_session_persistence(self, session_manager):
        """Test session creation and retrieval"""
        # Create session
        session = session_manager.create_session(user_id="test_user")
        session_id = session.session_id

        # Retrieve session
        retrieved = session_manager.get_session(session_id)

        assert retrieved is not None
        assert retrieved.session_id == session_id
        assert retrieved.user_id == "test_user"


class TestUserInterface:
    """Test user interface components"""

    def test_tag_visualization_format(self, rag_manager, sample_documents):
        """Test tag visualization format"""
        doc_data = sample_documents["emergency"]

        chunker = FlatTagChunker()
        chunks = chunker.chunk_document(
            doc_data["content"],
            language=doc_data["language"],
            user_tags=None
        )

        # Extract tags for visualization
        all_tags = []
        for chunk in chunks:
            tags = chunk.metadata.get('tags', [])
            if isinstance(tags, str):
                tags = [t.strip() for t in tags.split(',')]
            all_tags.extend(tags)

        # Tag visualization should be readable
        assert len(all_tags) > 0
        assert all(isinstance(tag, str) for tag in all_tags)

    def test_document_count_display(self, rag_manager, sample_documents):
        """Test document count accuracy"""
        # Index documents
        all_chunks = []
        for doc_data in sample_documents.values():
            chunker = FlatTagChunker()
            chunks = chunker.chunk_document(
                doc_data["content"],
                language=doc_data["language"],
                user_tags=None
            )
            all_chunks.extend(chunks)

        if all_chunks:
            rag_manager.vector_store.add_documents("documents", all_chunks)
            stats = rag_manager.vector_store.get_collection_stats("documents")

            # Should count unique documents, not chunks
            assert stats["document_count"] > 0
            assert stats["document_count"] <= len(sample_documents)  # Should be <= number of documents
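The string-vs-list tag normalization above is repeated in several tests here and in tests/test_user_scenarios.py; a small shared helper (hypothetical, not part of this commit) would capture it once:

# Hypothetical helper, not in this commit: normalize chunk tag metadata,
# which may be stored either as a comma-separated string or as a list.
def normalize_tags(metadata):
    tags = metadata.get('tags', [])
    if isinstance(tags, str):
        tags = [t.strip() for t in tags.split(',')]
    return [t.lower() for t in tags]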