BaHung
commited on
Commit
·
54b2662
0
Parent(s):
Clean repo
Browse files- .gitattributes +61 -0
- .gitignore +174 -0
- Dockerfile +0 -0
- README.md +0 -0
- api/__init__.py +0 -0
- api/app.py +0 -0
- api/routes/chat.py +0 -0
- api/routes/health.py +0 -0
- api/schemas.py +0 -0
- config/__init__.py +0 -0
- config/base.py +4 -0
- config/finetune_config.yaml +0 -0
- config/parse_config.yaml +3 -0
- config/rag_config.yaml +0 -0
- core/__init__.py +0 -0
- core/embeddings/__init__.py +0 -0
- core/embeddings/embedding_model.py +0 -0
- core/embeddings/vector_store.py +0 -0
- core/fine_tune/__init__.py +0 -0
- core/fine_tune/data_prep.py +0 -0
- core/fine_tune/evaluator.py +0 -0
- core/fine_tune/trainer.py +0 -0
- core/hash_file/__init__.py +0 -0
- core/hash_file/hash_data_goc.py +109 -0
- core/hash_file/hash_file.py +118 -0
- core/preprocessing/__init__.py +0 -0
- core/preprocessing/chunker.py +0 -0
- core/preprocessing/docling_processor.py +137 -0
- core/preprocessing/pdf_parser.py +75 -0
- requirements.txt +5 -0
- test/parse_data_hash_test.py +118 -0
- utils/__init__.py +0 -0
- utils/helpers.py +136 -0
- utils/logger.py +46 -0
- utils/metrics.py +102 -0
.gitattributes
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.mds filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
# Audio files - uncompressed
|
| 39 |
+
*.pcm filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.sam filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.raw filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
# Audio files - compressed
|
| 43 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
# Image files - uncompressed
|
| 49 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
# Image files - compressed
|
| 54 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
# Video files - compressed
|
| 58 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
*.webm filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
data/files/*.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Installer logs
|
| 35 |
+
pip-log.txt
|
| 36 |
+
pip-delete-this-directory.txt
|
| 37 |
+
|
| 38 |
+
# Unit test / coverage reports
|
| 39 |
+
htmlcov/
|
| 40 |
+
.tox/
|
| 41 |
+
.nox/
|
| 42 |
+
.coverage
|
| 43 |
+
.coverage.*
|
| 44 |
+
.cache
|
| 45 |
+
nosetests.xml
|
| 46 |
+
coverage.xml
|
| 47 |
+
*.cover
|
| 48 |
+
*.py,cover
|
| 49 |
+
.hypothesis/
|
| 50 |
+
.pytest_cache/
|
| 51 |
+
|
| 52 |
+
# Translations
|
| 53 |
+
*.mo
|
| 54 |
+
*.pot
|
| 55 |
+
|
| 56 |
+
# Django stuff:
|
| 57 |
+
*.log
|
| 58 |
+
local_settings.py
|
| 59 |
+
db.sqlite3
|
| 60 |
+
db.sqlite3-journal
|
| 61 |
+
|
| 62 |
+
# Flask stuff:
|
| 63 |
+
instance/
|
| 64 |
+
.webassets-cache
|
| 65 |
+
|
| 66 |
+
# Scrapy stuff:
|
| 67 |
+
.scrapy
|
| 68 |
+
|
| 69 |
+
# Sphinx documentation
|
| 70 |
+
docs/_build/
|
| 71 |
+
|
| 72 |
+
# PyBuilder
|
| 73 |
+
target/
|
| 74 |
+
|
| 75 |
+
# Jupyter Notebook
|
| 76 |
+
.ipynb_checkpoints
|
| 77 |
+
|
| 78 |
+
# IPython
|
| 79 |
+
profile_default/
|
| 80 |
+
ipython_config.py
|
| 81 |
+
|
| 82 |
+
# pyenv
|
| 83 |
+
.python-version
|
| 84 |
+
|
| 85 |
+
__pypackages__/
|
| 86 |
+
|
| 87 |
+
# Celery stuff
|
| 88 |
+
celerybeat-schedule
|
| 89 |
+
celerybeat.pid
|
| 90 |
+
|
| 91 |
+
# SageMath parsed files
|
| 92 |
+
*.sage.py
|
| 93 |
+
|
| 94 |
+
# Environment variables
|
| 95 |
+
.env
|
| 96 |
+
.venv
|
| 97 |
+
env/
|
| 98 |
+
venv/
|
| 99 |
+
ENV/
|
| 100 |
+
env.bak/
|
| 101 |
+
venv.bak/
|
| 102 |
+
|
| 103 |
+
# Spyder project settings
|
| 104 |
+
.spyderproject
|
| 105 |
+
.spyproject
|
| 106 |
+
|
| 107 |
+
# Rope project settings
|
| 108 |
+
.ropeproject
|
| 109 |
+
|
| 110 |
+
# mkdocs documentation
|
| 111 |
+
/site
|
| 112 |
+
|
| 113 |
+
# mypy
|
| 114 |
+
.mypy_cache/
|
| 115 |
+
.dmypy.json
|
| 116 |
+
dmypy.json
|
| 117 |
+
|
| 118 |
+
# Pyre type checker
|
| 119 |
+
.pyre/
|
| 120 |
+
|
| 121 |
+
# IDE files
|
| 122 |
+
.vscode/
|
| 123 |
+
.idea/
|
| 124 |
+
*.swp
|
| 125 |
+
*.swo
|
| 126 |
+
*~
|
| 127 |
+
|
| 128 |
+
# OS generated files
|
| 129 |
+
.DS_Store
|
| 130 |
+
.DS_Store?
|
| 131 |
+
._*
|
| 132 |
+
.Spotlight-V100
|
| 133 |
+
.Trashes
|
| 134 |
+
ehthumbs.db
|
| 135 |
+
Thumbs.db
|
| 136 |
+
|
| 137 |
+
# Project specific
|
| 138 |
+
marker_out/
|
| 139 |
+
test_input/
|
| 140 |
+
test_output/
|
| 141 |
+
chunks/
|
| 142 |
+
chunking_analysis/
|
| 143 |
+
test_pipeline_*/
|
| 144 |
+
|
| 145 |
+
# Model cache
|
| 146 |
+
.cache/
|
| 147 |
+
models/
|
| 148 |
+
*.safetensors
|
| 149 |
+
*.bin
|
| 150 |
+
*.onnx
|
| 151 |
+
|
| 152 |
+
# Temporary files
|
| 153 |
+
*.tmp
|
| 154 |
+
*.temp
|
| 155 |
+
temp/
|
| 156 |
+
tmp/
|
| 157 |
+
|
| 158 |
+
# Log files
|
| 159 |
+
*.log
|
| 160 |
+
logs/
|
| 161 |
+
|
| 162 |
+
# API keys and sensitive data
|
| 163 |
+
config.json
|
| 164 |
+
secrets.json
|
| 165 |
+
api_keys.txt
|
| 166 |
+
|
| 167 |
+
# Backup files
|
| 168 |
+
*.bak
|
| 169 |
+
*.backup
|
| 170 |
+
*~
|
| 171 |
+
__pycache__/
|
| 172 |
+
|
| 173 |
+
/model/
|
| 174 |
+
/data/
|
Dockerfile
ADDED
|
File without changes
|
README.md
ADDED
|
File without changes
|
api/__init__.py
ADDED
|
File without changes
|
api/app.py
ADDED
|
File without changes
|
api/routes/chat.py
ADDED
|
File without changes
|
api/routes/health.py
ADDED
|
File without changes
|
api/schemas.py
ADDED
|
File without changes
|
config/__init__.py
ADDED
|
File without changes
|
config/base.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
|
| 3 |
+
# Login using e.g. `huggingface-cli login` to access this dataset
|
| 4 |
+
ds = load_dataset("hungnha/Do_An_Dataset")
|
config/finetune_config.yaml
ADDED
|
File without changes
|
config/parse_config.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
INPUT_PATH:''
|
| 2 |
+
OUTPUT_PATH:''
|
| 3 |
+
|
config/rag_config.yaml
ADDED
|
File without changes
|
core/__init__.py
ADDED
|
File without changes
|
core/embeddings/__init__.py
ADDED
|
File without changes
|
core/embeddings/embedding_model.py
ADDED
|
File without changes
|
core/embeddings/vector_store.py
ADDED
|
File without changes
|
core/fine_tune/__init__.py
ADDED
|
File without changes
|
core/fine_tune/data_prep.py
ADDED
|
File without changes
|
core/fine_tune/evaluator.py
ADDED
|
File without changes
|
core/fine_tune/trainer.py
ADDED
|
File without changes
|
core/hash_file/__init__.py
ADDED
|
File without changes
|
core/hash_file/hash_data_goc.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Setup path
|
| 7 |
+
current_file = Path(__file__).resolve()
|
| 8 |
+
project_root = current_file.parent.parent.parent
|
| 9 |
+
if str(project_root) not in sys.path:
|
| 10 |
+
sys.path.insert(0, str(project_root))
|
| 11 |
+
|
| 12 |
+
from typing import cast, Dict, Any
|
| 13 |
+
from datasets import load_dataset, Dataset
|
| 14 |
+
from core.hash_file.hash_file import HashProcessor
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
# Khởi tạo
|
| 18 |
+
data_dir = project_root / "data"
|
| 19 |
+
files_dir = data_dir / "files"
|
| 20 |
+
files_dir.mkdir(parents=True, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
hash_processor = HashProcessor(verbose=False)
|
| 23 |
+
hash_file_path = data_dir / "hash_data_goc_index.json"
|
| 24 |
+
|
| 25 |
+
# Load existing hash index
|
| 26 |
+
existing_hashes = {}
|
| 27 |
+
if hash_file_path.exists():
|
| 28 |
+
with open(hash_file_path, 'r', encoding='utf-8') as f:
|
| 29 |
+
data = json.load(f)
|
| 30 |
+
existing_hashes = {item['index']: item['hash'] for item in data.get('train', [])}
|
| 31 |
+
print(f"📂 Đã tải {len(existing_hashes)} hash từ index cũ")
|
| 32 |
+
|
| 33 |
+
# Load dataset
|
| 34 |
+
print("📥 Đang tải dataset từ Hugging Face...")
|
| 35 |
+
dataset = load_dataset("hungnha/Do_An_Dataset")
|
| 36 |
+
train_dataset = cast(Dataset, dataset['train'])
|
| 37 |
+
print(f"✅ Đã tải {len(train_dataset)} files\n")
|
| 38 |
+
|
| 39 |
+
# Xử lý từng file
|
| 40 |
+
hash_results = []
|
| 41 |
+
skipped = 0
|
| 42 |
+
processed = 0
|
| 43 |
+
|
| 44 |
+
for idx, sample in enumerate(train_dataset):
|
| 45 |
+
sample = cast(Dict[str, Any], sample)
|
| 46 |
+
filename = f"train_{idx:04d}.pdf"
|
| 47 |
+
filepath = files_dir / filename
|
| 48 |
+
|
| 49 |
+
# Kiểm tra file đã tồn tại chưa
|
| 50 |
+
if filepath.exists() and idx in existing_hashes:
|
| 51 |
+
# Verify hash
|
| 52 |
+
current_hash = hash_processor.get_file_hash(str(filepath))
|
| 53 |
+
if current_hash == existing_hashes[idx]:
|
| 54 |
+
hash_results.append({
|
| 55 |
+
'filename': filename,
|
| 56 |
+
'hash': current_hash,
|
| 57 |
+
'index': idx
|
| 58 |
+
})
|
| 59 |
+
skipped += 1
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
# Lấy PDF object
|
| 64 |
+
pdf_obj = sample['pdf']
|
| 65 |
+
|
| 66 |
+
# Xử lý dữ liệu từ datasets (thường là dict chứa bytes hoặc bytes trực tiếp)
|
| 67 |
+
if isinstance(pdf_obj, dict) and 'bytes' in pdf_obj:
|
| 68 |
+
pdf_bytes = pdf_obj['bytes']
|
| 69 |
+
elif isinstance(pdf_obj, bytes):
|
| 70 |
+
pdf_bytes = pdf_obj
|
| 71 |
+
else:
|
| 72 |
+
print(f"⚠️ Bỏ qua file {idx} - định dạng dữ liệu không hỗ trợ: {type(pdf_obj)}")
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
# Lưu file
|
| 76 |
+
filepath.write_bytes(pdf_bytes)
|
| 77 |
+
|
| 78 |
+
# Tính hash
|
| 79 |
+
file_hash = hash_processor.get_file_hash(str(filepath))
|
| 80 |
+
if file_hash is None:
|
| 81 |
+
print(f"❌ Lỗi tính hash cho file {idx}")
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
hash_results.append({
|
| 85 |
+
'filename': filename,
|
| 86 |
+
'hash': file_hash,
|
| 87 |
+
'index': idx
|
| 88 |
+
})
|
| 89 |
+
processed += 1
|
| 90 |
+
|
| 91 |
+
if (idx + 1) % 10 == 0:
|
| 92 |
+
print(f"📄 Đã xử lý {idx + 1}/{len(train_dataset)} files (mới: {processed}, bỏ qua: {skipped})")
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
print(f"❌ Lỗi xử lý file {idx}: {e}")
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
# Lưu hash index
|
| 99 |
+
hash_file_path.write_text(json.dumps({'train': hash_results}, indent=2, ensure_ascii=False))
|
| 100 |
+
|
| 101 |
+
print(f"\n✅ Hoàn thành!")
|
| 102 |
+
print(f" - Đã xử lý mới: {processed} files")
|
| 103 |
+
print(f" - Đã bỏ qua: {skipped} files")
|
| 104 |
+
print(f" - Tổng cộng: {len(hash_results)} files")
|
| 105 |
+
print(f"📁 Thư mục files: {files_dir}")
|
| 106 |
+
print(f"📄 Hash index: {hash_file_path}")
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
main()
|
core/hash_file/hash_file.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
# Constants
|
| 11 |
+
CHUNK_SIZE = 8192 # 8KB chunks for reading files
|
| 12 |
+
DEFAULT_FILE_EXTENSION = '.pdf'
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.INFO,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
class HashProcessor:
|
| 21 |
+
"""Simplified HashProcessor for RAG system - only core functionality."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, verbose: bool = True):
|
| 24 |
+
self.verbose = verbose
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
if not verbose:
|
| 27 |
+
self.logger.setLevel(logging.WARNING)
|
| 28 |
+
|
| 29 |
+
def get_file_hash(self, path: str) -> Optional[str]:
|
| 30 |
+
"""Calculate SHA256 hash of file."""
|
| 31 |
+
h = hashlib.sha256()
|
| 32 |
+
try:
|
| 33 |
+
with open(path, "rb") as f:
|
| 34 |
+
while chunk := f.read(CHUNK_SIZE):
|
| 35 |
+
h.update(chunk)
|
| 36 |
+
return h.hexdigest()
|
| 37 |
+
except (IOError, OSError) as e:
|
| 38 |
+
self.logger.error(f"Lỗi khi đọc file {path}: {e}")
|
| 39 |
+
return None
|
| 40 |
+
except Exception as e:
|
| 41 |
+
self.logger.error(f"Lỗi không xác định khi xử lý file {path}: {e}")
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
def scan_files_for_hash(
|
| 45 |
+
self,
|
| 46 |
+
source_dir: str,
|
| 47 |
+
file_extension: str = DEFAULT_FILE_EXTENSION
|
| 48 |
+
) -> Dict[str, List[Dict[str, str]]]:
|
| 49 |
+
"""Scan directory and calculate hash for each file."""
|
| 50 |
+
if not os.path.exists(source_dir):
|
| 51 |
+
raise FileNotFoundError(f"Thư mục không tồn tại: {source_dir}")
|
| 52 |
+
|
| 53 |
+
if not os.path.isdir(source_dir):
|
| 54 |
+
raise NotADirectoryError(f"Đường dẫn không phải là thư mục: {source_dir}")
|
| 55 |
+
|
| 56 |
+
hash_to_files = defaultdict(list)
|
| 57 |
+
|
| 58 |
+
self.logger.info(f"Đang quét file trong thư mục: {source_dir}")
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
files = [f for f in os.listdir(source_dir)
|
| 62 |
+
if f.lower().endswith(file_extension.lower())]
|
| 63 |
+
|
| 64 |
+
for filename in files:
|
| 65 |
+
file_path = os.path.join(source_dir, filename)
|
| 66 |
+
|
| 67 |
+
if not os.path.isfile(file_path):
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
self.logger.info(f"Đang tính hash cho: {filename}")
|
| 71 |
+
|
| 72 |
+
file_hash = self.get_file_hash(file_path)
|
| 73 |
+
if file_hash:
|
| 74 |
+
hash_to_files[file_hash].append({
|
| 75 |
+
'filename': filename,
|
| 76 |
+
'path': file_path,
|
| 77 |
+
'size': os.path.getsize(file_path)
|
| 78 |
+
})
|
| 79 |
+
except PermissionError as e:
|
| 80 |
+
self.logger.error(f"Không có quyền truy cập thư mục {source_dir}: {e}")
|
| 81 |
+
raise
|
| 82 |
+
|
| 83 |
+
return hash_to_files
|
| 84 |
+
|
| 85 |
+
def load_processed_index(self, index_file: str) -> Dict:
|
| 86 |
+
"""Load processed index from JSON file."""
|
| 87 |
+
if os.path.exists(index_file):
|
| 88 |
+
try:
|
| 89 |
+
with open(index_file, "r", encoding="utf-8") as f:
|
| 90 |
+
return json.load(f)
|
| 91 |
+
except json.JSONDecodeError as e:
|
| 92 |
+
self.logger.error(f"Lỗi đọc file index {index_file}: {e}")
|
| 93 |
+
return {}
|
| 94 |
+
except Exception as e:
|
| 95 |
+
self.logger.error(f"Lỗi không xác định khi đọc index: {e}")
|
| 96 |
+
return {}
|
| 97 |
+
return {}
|
| 98 |
+
|
| 99 |
+
def save_processed_index(self, index_file: str, processed_hashes: Dict) -> None:
|
| 100 |
+
"""Save processed index to JSON file."""
|
| 101 |
+
try:
|
| 102 |
+
# Tạo thư mục nếu chưa tồn tại
|
| 103 |
+
os.makedirs(os.path.dirname(index_file), exist_ok=True)
|
| 104 |
+
|
| 105 |
+
with open(index_file, "w", encoding="utf-8") as f:
|
| 106 |
+
json.dump(processed_hashes, f, indent=2, ensure_ascii=False)
|
| 107 |
+
self.logger.info(f"Đã lưu index file: {index_file}")
|
| 108 |
+
except Exception as e:
|
| 109 |
+
self.logger.error(f"Lỗi khi lưu index file {index_file}: {e}")
|
| 110 |
+
|
| 111 |
+
def get_current_timestamp(self) -> str:
|
| 112 |
+
"""Get current timestamp in ISO format."""
|
| 113 |
+
return datetime.now().isoformat()
|
| 114 |
+
|
| 115 |
+
def get_string_hash(self, text: str) -> str:
|
| 116 |
+
"""Calculate SHA256 hash of string."""
|
| 117 |
+
return hashlib.sha256(text.encode('utf-8')).hexdigest()
|
| 118 |
+
|
core/preprocessing/__init__.py
ADDED
|
File without changes
|
core/preprocessing/chunker.py
ADDED
|
File without changes
|
core/preprocessing/docling_processor.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import signal
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
# Import dependencies
|
| 10 |
+
from core.hash_file.hash_file import HashProcessor
|
| 11 |
+
|
| 12 |
+
from docling.document_converter import DocumentConverter, FormatOption
|
| 13 |
+
from docling.datamodel.base_models import InputFormat
|
| 14 |
+
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
| 15 |
+
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
| 16 |
+
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DoclingProcessor:
|
| 21 |
+
|
| 22 |
+
def __init__(self, output_dir: str, use_ocr: bool = False, timeout: int = 300):
|
| 23 |
+
self.output_dir = output_dir
|
| 24 |
+
self.use_ocr = use_ocr
|
| 25 |
+
self.timeout = timeout
|
| 26 |
+
self.hash_processor = HashProcessor(verbose=False)
|
| 27 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
# Cache system - lưu index ở ngoài output_dir
|
| 30 |
+
data_dir = Path(output_dir).parent
|
| 31 |
+
self.index_file = str(data_dir / "hash_docling_index.json")
|
| 32 |
+
self.parsed_docs = self.hash_processor.load_processed_index(self.index_file)
|
| 33 |
+
|
| 34 |
+
# Cấu hình OCR settings
|
| 35 |
+
if not use_ocr:
|
| 36 |
+
# Tạo PDF pipeline options với OCR tắt
|
| 37 |
+
pdf_pipeline_options = PdfPipelineOptions(
|
| 38 |
+
do_ocr=False, # Tắt OCR hoàn toàn
|
| 39 |
+
do_table_structure=True, # Vẫn giữ table structure
|
| 40 |
+
do_picture_classification=False,
|
| 41 |
+
do_picture_description=False
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Tạo format options cho PDF
|
| 45 |
+
format_options = {
|
| 46 |
+
InputFormat.PDF: FormatOption(
|
| 47 |
+
backend=PyPdfiumDocumentBackend,
|
| 48 |
+
pipeline_cls=StandardPdfPipeline,
|
| 49 |
+
pipeline_options=pdf_pipeline_options
|
| 50 |
+
)
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
self.converter = DocumentConverter(format_options=format_options)
|
| 54 |
+
print("🔧 OCR completely disabled for docling")
|
| 55 |
+
else:
|
| 56 |
+
# Sử dụng converter mặc định với OCR enabled
|
| 57 |
+
self.converter = DocumentConverter()
|
| 58 |
+
print("🔧 OCR enabled for docling")
|
| 59 |
+
|
| 60 |
+
def parse_document(self, file_path: str) -> Optional[Dict]:
|
| 61 |
+
"""Parse single document - có cache system!"""
|
| 62 |
+
if not os.path.exists(file_path):
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
filename = os.path.basename(file_path)
|
| 66 |
+
file_hash = self.hash_processor.get_file_hash(file_path)
|
| 67 |
+
|
| 68 |
+
# Kiểm tra cache trước
|
| 69 |
+
if file_hash in self.parsed_docs:
|
| 70 |
+
cached_info = self.parsed_docs[file_hash]
|
| 71 |
+
output_path = os.path.join(self.output_dir, cached_info['output_file'])
|
| 72 |
+
if os.path.exists(output_path):
|
| 73 |
+
print(f"⏭️ Already parsed: {filename}")
|
| 74 |
+
with open(output_path, 'r', encoding='utf-8') as f:
|
| 75 |
+
return json.load(f)
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
print(f"🔄 Processing: {filename}...")
|
| 79 |
+
|
| 80 |
+
# Set timeout alarm
|
| 81 |
+
signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(TimeoutError("Processing timeout")))
|
| 82 |
+
signal.alarm(self.timeout)
|
| 83 |
+
|
| 84 |
+
result = self.converter.convert(file_path)
|
| 85 |
+
docling_json = result.document.export_to_dict()
|
| 86 |
+
|
| 87 |
+
# Cancel timeout
|
| 88 |
+
signal.alarm(0)
|
| 89 |
+
|
| 90 |
+
except TimeoutError:
|
| 91 |
+
print(f"⏰ Timeout processing {filename} (>{self.timeout}s)")
|
| 92 |
+
signal.alarm(0)
|
| 93 |
+
return None
|
| 94 |
+
except Exception as e:
|
| 95 |
+
print(f"❌ Failed to parse {filename}: {e}")
|
| 96 |
+
signal.alarm(0)
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
# Lưu file JSON với tên có hash để tránh trùng lặp
|
| 100 |
+
output_file = f"{Path(filename).stem}_{file_hash[:8]}.json"
|
| 101 |
+
output_path = os.path.join(self.output_dir, output_file)
|
| 102 |
+
|
| 103 |
+
# Lưu output mặc định của Docling - minified như Docling gốc
|
| 104 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 105 |
+
json.dump(docling_json, f, ensure_ascii=False)
|
| 106 |
+
|
| 107 |
+
# Cập nhật cache
|
| 108 |
+
self.parsed_docs[file_hash] = {
|
| 109 |
+
"filename": filename,
|
| 110 |
+
"output_file": output_file,
|
| 111 |
+
"parsed_date": datetime.now().isoformat()
|
| 112 |
+
}
|
| 113 |
+
self.hash_processor.save_processed_index(self.index_file, self.parsed_docs)
|
| 114 |
+
|
| 115 |
+
print(f"✓ Parsed: {filename}")
|
| 116 |
+
return docling_json
|
| 117 |
+
|
| 118 |
+
def parse_directory(self, source_dir: str) -> Dict:
|
| 119 |
+
"""Parse all PDFs in directory - tận dụng HashProcessor"""
|
| 120 |
+
print(f"Parsing PDFs in: {source_dir}")
|
| 121 |
+
|
| 122 |
+
# Tận dụng HashProcessor để scan files
|
| 123 |
+
hash_to_files = self.hash_processor.scan_files_for_hash(source_dir, '.pdf')
|
| 124 |
+
|
| 125 |
+
results = {"total": 0, "parsed": 0, "errors": 0}
|
| 126 |
+
|
| 127 |
+
for file_hash, file_list in hash_to_files.items():
|
| 128 |
+
for file_info in file_list:
|
| 129 |
+
results["total"] += 1
|
| 130 |
+
result = self.parse_document(file_info['path'])
|
| 131 |
+
if result:
|
| 132 |
+
results["parsed"] += 1
|
| 133 |
+
else:
|
| 134 |
+
results["errors"] += 1
|
| 135 |
+
|
| 136 |
+
print(f"Summary: {results['parsed']}/{results['total']} files parsed")
|
| 137 |
+
return results
|
core/preprocessing/pdf_parser.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import traceback
|
| 4 |
+
import warnings
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from core.preprocessing.docling_processor import DoclingProcessor
|
| 9 |
+
|
| 10 |
+
# Tắt cảnh báo pin_memory từ docling/PyTorch
|
| 11 |
+
warnings.filterwarnings("ignore", message=".*pin_memory.*")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_project_paths():
|
| 15 |
+
# Lấy từ data/files
|
| 16 |
+
source_dir = Path("data/files").resolve()
|
| 17 |
+
output_dir = Path("data/docling_output").resolve()
|
| 18 |
+
return str(source_dir), str(output_dir)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main(source_dir=None, output_dir=None, use_ocr=False, timeout=300):
|
| 22 |
+
"""Parse PDF documents."""
|
| 23 |
+
|
| 24 |
+
# Auto-detect paths nếu không được cung cấp
|
| 25 |
+
if source_dir is None or output_dir is None:
|
| 26 |
+
auto_source, auto_output = get_project_paths()
|
| 27 |
+
source_dir = source_dir or auto_source
|
| 28 |
+
output_dir = output_dir or auto_output
|
| 29 |
+
|
| 30 |
+
# Kiểm tra source directory
|
| 31 |
+
if not os.path.exists(source_dir):
|
| 32 |
+
print(f"❌ Source not found: {source_dir}")
|
| 33 |
+
print(f"\n💡 Solution:")
|
| 34 |
+
print(f" 1. Run hash_data_goc.py first to download PDFs")
|
| 35 |
+
print(f" 2. Or specify path: python parse_data_hash.py --source /path/to/pdfs")
|
| 36 |
+
return 1
|
| 37 |
+
|
| 38 |
+
print(f"📂 Source: {source_dir}")
|
| 39 |
+
print(f"📁 Output: {output_dir}\n")
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
processor = DoclingProcessor(
|
| 43 |
+
output_dir=output_dir,
|
| 44 |
+
use_ocr=use_ocr,
|
| 45 |
+
timeout=timeout
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
results = processor.parse_directory(source_dir)
|
| 49 |
+
|
| 50 |
+
print(f"\n📊 Total: {results['total']} docs | "
|
| 51 |
+
f"Parsed: {results['parsed']} | Errors: {results['errors']}\n")
|
| 52 |
+
|
| 53 |
+
return 0
|
| 54 |
+
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"\n❌ Error: {e}")
|
| 57 |
+
traceback.print_exc()
|
| 58 |
+
return 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
import argparse
|
| 63 |
+
parser = argparse.ArgumentParser(description="Parse PDFs with Docling")
|
| 64 |
+
parser.add_argument("--source", help="Source directory with PDFs")
|
| 65 |
+
parser.add_argument("--output", help="Output directory for results")
|
| 66 |
+
parser.add_argument("--ocr", action="store_true", help="Enable OCR")
|
| 67 |
+
parser.add_argument("--timeout", type=int, default=300, help="Timeout per file in seconds (default: 300)")
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
exit(main(
|
| 71 |
+
source_dir=args.source,
|
| 72 |
+
output_dir=args.output,
|
| 73 |
+
use_ocr=args.ocr,
|
| 74 |
+
timeout=args.timeout
|
| 75 |
+
))
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#unsloth
|
| 2 |
+
#langchain
|
| 3 |
+
docling
|
| 4 |
+
datasets
|
| 5 |
+
pdfplumber
|
test/parse_data_hash_test.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
# Ensure project root is on sys.path so `core` and `config` can be imported
|
| 6 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 7 |
+
if _PROJECT_ROOT not in sys.path:
|
| 8 |
+
sys.path.insert(0, _PROJECT_ROOT)
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
+
from core.preprocessing.docling_processor import DoclingProcessor
|
| 12 |
+
from config.base import ds
|
| 13 |
+
|
| 14 |
+
REPO_ID = "hungnha/Do_An_Dataset"
|
| 15 |
+
|
| 16 |
+
def _extract_pdf_path_from_example(example):
|
| 17 |
+
# Tìm path PDF trong example (ưu tiên các giá trị string kết thúc .pdf và tồn tại trên máy/cached)
|
| 18 |
+
if isinstance(example, dict):
|
| 19 |
+
for value in example.values():
|
| 20 |
+
if isinstance(value, str) and value.lower().endswith('.pdf') and os.path.exists(value):
|
| 21 |
+
return value
|
| 22 |
+
# Không tìm thấy
|
| 23 |
+
return None
|
| 24 |
+
|
| 25 |
+
def _download_random_pdf_from_hub(repo_id: str) -> str:
|
| 26 |
+
api = HfApi()
|
| 27 |
+
files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
|
| 28 |
+
pdf_files = [f for f in files if f.lower().endswith('.pdf')]
|
| 29 |
+
if not pdf_files:
|
| 30 |
+
return None
|
| 31 |
+
chosen = random.choice(pdf_files)
|
| 32 |
+
# Tải về cache local và trả về đường dẫn
|
| 33 |
+
try:
|
| 34 |
+
local_path = hf_hub_download(repo_id=repo_id, filename=chosen, repo_type="dataset")
|
| 35 |
+
return local_path
|
| 36 |
+
except Exception:
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
def main(output_dir=None, use_ocr=False):
|
| 40 |
+
"""Parse PDF documents - test mode chỉ chạy 1 file random."""
|
| 41 |
+
|
| 42 |
+
# Auto-detect output path (dataset đọc từ cache HF, không dùng source_dir local)
|
| 43 |
+
if output_dir is None:
|
| 44 |
+
output_dir = "core/data"
|
| 45 |
+
|
| 46 |
+
# Lấy split chính (mặc định 'train' nếu có)
|
| 47 |
+
split = 'train' if 'train' in ds else list(ds.keys())[0]
|
| 48 |
+
dataset_split = ds[split]
|
| 49 |
+
print(f"📚 Using split: {split} (n={len(dataset_split)})")
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
# Ưu tiên: tải ngẫu nhiên 1 PDF trực tiếp từ repo dataset trên Hugging Face
|
| 53 |
+
file_path = _download_random_pdf_from_hub(REPO_ID)
|
| 54 |
+
if file_path is None:
|
| 55 |
+
# Fallback: thử lấy từ example đã cache (nếu dataset lưu sẵn đường dẫn local)
|
| 56 |
+
if len(dataset_split) == 0:
|
| 57 |
+
print("❌ Dataset split is empty")
|
| 58 |
+
return 1
|
| 59 |
+
attempts = 0
|
| 60 |
+
while attempts < 32 and file_path is None:
|
| 61 |
+
idx = random.randint(0, len(dataset_split) - 1)
|
| 62 |
+
example = dataset_split[idx]
|
| 63 |
+
candidate = _extract_pdf_path_from_example(example)
|
| 64 |
+
if candidate is not None:
|
| 65 |
+
file_path = candidate
|
| 66 |
+
break
|
| 67 |
+
attempts += 1
|
| 68 |
+
if file_path is None:
|
| 69 |
+
print("❌ Could not locate any PDF (hub or cache)")
|
| 70 |
+
return 1
|
| 71 |
+
|
| 72 |
+
random_file = os.path.basename(file_path)
|
| 73 |
+
|
| 74 |
+
print(f"🎯 Testing with: {random_file}\n")
|
| 75 |
+
|
| 76 |
+
# Khởi tạo processor (vẫn dùng cache system)
|
| 77 |
+
processor = DoclingProcessor(
|
| 78 |
+
output_dir=output_dir,
|
| 79 |
+
use_ocr=use_ocr,
|
| 80 |
+
timeout=300
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Parse 1 file
|
| 84 |
+
result = processor.parse_document(file_path)
|
| 85 |
+
|
| 86 |
+
if result:
|
| 87 |
+
print(f"\n✅ Test successful!")
|
| 88 |
+
print(f"📊 Parsed: {random_file}")
|
| 89 |
+
# Kiểm tra output file
|
| 90 |
+
random_stem = os.path.splitext(random_file)[0]
|
| 91 |
+
output_files = [
|
| 92 |
+
f for f in os.listdir(output_dir)
|
| 93 |
+
if random_stem in f
|
| 94 |
+
]
|
| 95 |
+
if output_files:
|
| 96 |
+
print(f"📄 Output: {output_files[0]}")
|
| 97 |
+
else:
|
| 98 |
+
print(f"\n❌ Test failed for: {random_file}")
|
| 99 |
+
return 1
|
| 100 |
+
|
| 101 |
+
return 0
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"\n❌ Error: {e}")
|
| 105 |
+
return 1
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
import argparse
|
| 110 |
+
parser = argparse.ArgumentParser(description="Test Docling with 1 random PDF from HF cache")
|
| 111 |
+
parser.add_argument("--output", help="Output directory")
|
| 112 |
+
parser.add_argument("--ocr", action="store_true", help="Enable OCR")
|
| 113 |
+
args = parser.parse_args()
|
| 114 |
+
|
| 115 |
+
sys.exit(main(
|
| 116 |
+
output_dir=args.output,
|
| 117 |
+
use_ocr=args.ocr
|
| 118 |
+
))
|
utils/__init__.py
ADDED
|
File without changes
|
utils/helpers.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar
|
| 10 |
+
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
T = TypeVar("T")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Filesystem helpers
|
| 17 |
+
def ensure_dir(path: str | os.PathLike) -> str:
|
| 18 |
+
p = Path(path)
|
| 19 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 20 |
+
return str(p)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def read_json(path: str | os.PathLike) -> Any:
|
| 24 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 25 |
+
return json.load(f)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def write_json(data: Any, path: str | os.PathLike, *, indent: int = 2) -> None:
|
| 29 |
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 31 |
+
json.dump(data, f, ensure_ascii=False, indent=indent)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def read_yaml(path: str | os.PathLike) -> Any:
|
| 35 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 36 |
+
return yaml.safe_load(f)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def write_yaml(data: Any, path: str | os.PathLike) -> None:
|
| 40 |
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 41 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 42 |
+
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# General helpers
|
| 46 |
+
def set_seed(seed: int) -> None:
|
| 47 |
+
random.seed(seed)
|
| 48 |
+
try:
|
| 49 |
+
import numpy as np # type: ignore
|
| 50 |
+
np.random.seed(seed)
|
| 51 |
+
except Exception:
|
| 52 |
+
pass
|
| 53 |
+
try:
|
| 54 |
+
import torch # type: ignore
|
| 55 |
+
torch.manual_seed(seed)
|
| 56 |
+
torch.cuda.manual_seed_all(seed)
|
| 57 |
+
torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
| 58 |
+
torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
| 59 |
+
except Exception:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_env(key: str, default: Optional[str] = None) -> Optional[str]:
|
| 64 |
+
val = os.getenv(key)
|
| 65 |
+
return val if val is not None else default
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def slugify_filename(name: str, max_len: int = 128) -> str:
|
| 69 |
+
base = re.sub(r"[^a-zA-Z0-9._-]+", "-", name).strip("-._")
|
| 70 |
+
return base[:max_len]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def safe_stem(path: str | os.PathLike) -> str:
|
| 74 |
+
p = Path(path)
|
| 75 |
+
return p.stem
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def batched(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
|
| 79 |
+
batch: List[T] = []
|
| 80 |
+
for item in iterable:
|
| 81 |
+
batch.append(item)
|
| 82 |
+
if len(batch) >= batch_size:
|
| 83 |
+
yield batch
|
| 84 |
+
batch = []
|
| 85 |
+
if batch:
|
| 86 |
+
yield batch
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Timing and retry utilities
|
| 90 |
+
def timeit(func: Callable[..., T]) -> Callable[..., T]:
|
| 91 |
+
def wrapper(*args: Any, **kwargs: Any) -> T:
|
| 92 |
+
start = time.perf_counter()
|
| 93 |
+
try:
|
| 94 |
+
return func(*args, **kwargs)
|
| 95 |
+
finally:
|
| 96 |
+
elapsed = (time.perf_counter() - start) * 1000
|
| 97 |
+
print(f"⏱️ {func.__name__} took {elapsed:.2f} ms")
|
| 98 |
+
return wrapper
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def retry(
|
| 102 |
+
exceptions: Tuple[type[BaseException], ...] = (Exception,),
|
| 103 |
+
tries: int = 3,
|
| 104 |
+
delay: float = 0.5,
|
| 105 |
+
backoff: float = 2.0,
|
| 106 |
+
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
| 107 |
+
def decorator(fn: Callable[..., T]) -> Callable[..., T]:
|
| 108 |
+
def inner(*args: Any, **kwargs: Any) -> T:
|
| 109 |
+
_tries, _delay = tries, delay
|
| 110 |
+
while _tries > 1:
|
| 111 |
+
try:
|
| 112 |
+
return fn(*args, **kwargs)
|
| 113 |
+
except exceptions:
|
| 114 |
+
time.sleep(_delay)
|
| 115 |
+
_tries -= 1
|
| 116 |
+
_delay *= backoff
|
| 117 |
+
return fn(*args, **kwargs)
|
| 118 |
+
return inner
|
| 119 |
+
return decorator
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Text utilities helpful for RAG
|
| 123 |
+
def normalize_text(text: str) -> str:
|
| 124 |
+
text = text.replace("\u00A0", " ") # non-breaking space
|
| 125 |
+
text = re.sub(r"\s+", " ", text)
|
| 126 |
+
return text.strip()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def strip_markdown(text: str) -> str:
|
| 130 |
+
# very light-weight markdown stripper for indexing
|
| 131 |
+
text = re.sub(r"`{1,3}[^`]*`{1,3}", " ", text) # code spans/blocks
|
| 132 |
+
text = re.sub(r"\[(.*?)\]\((.*?)\)", r"\1", text) # links
|
| 133 |
+
text = re.sub(r"[#>*_~`]+", " ", text) # punctuation markers
|
| 134 |
+
return normalize_text(text)
|
| 135 |
+
|
| 136 |
+
|
utils/logger.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _get_level(level: Optional[str | int]) -> int:
|
| 7 |
+
if isinstance(level, int):
|
| 8 |
+
return level
|
| 9 |
+
if isinstance(level, str):
|
| 10 |
+
try:
|
| 11 |
+
return getattr(logging, level.upper())
|
| 12 |
+
except AttributeError:
|
| 13 |
+
return logging.INFO
|
| 14 |
+
# ENV override
|
| 15 |
+
env_level = os.getenv("LOG_LEVEL")
|
| 16 |
+
if env_level:
|
| 17 |
+
return getattr(logging, env_level.upper(), logging.INFO)
|
| 18 |
+
return logging.INFO
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def setup_root_logger(level: Optional[str | int] = None) -> None:
|
| 22 |
+
"""Configure root logger once. Safe to call multiple times."""
|
| 23 |
+
if getattr(setup_root_logger, "_configured", False):
|
| 24 |
+
return
|
| 25 |
+
|
| 26 |
+
resolved = _get_level(level)
|
| 27 |
+
logging.basicConfig(
|
| 28 |
+
level=resolved,
|
| 29 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 30 |
+
)
|
| 31 |
+
setup_root_logger._configured = True # type: ignore[attr-defined]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_logger(name: Optional[str] = None, level: Optional[str | int] = None) -> logging.Logger:
|
| 35 |
+
"""Create or fetch a module-scoped logger with consistent formatting.
|
| 36 |
+
|
| 37 |
+
- Honors LOG_LEVEL env if level not provided.
|
| 38 |
+
- Does not add duplicate handlers on repeated calls.
|
| 39 |
+
"""
|
| 40 |
+
setup_root_logger(level)
|
| 41 |
+
logger = logging.getLogger(name if name else __name__)
|
| 42 |
+
if level is not None:
|
| 43 |
+
logger.setLevel(_get_level(level))
|
| 44 |
+
return logger
|
| 45 |
+
|
| 46 |
+
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class RollingAverage:
|
| 9 |
+
window: int = 100
|
| 10 |
+
values: List[float] = field(default_factory=list)
|
| 11 |
+
|
| 12 |
+
def add(self, x: float) -> None:
|
| 13 |
+
self.values.append(x)
|
| 14 |
+
if len(self.values) > self.window:
|
| 15 |
+
self.values.pop(0)
|
| 16 |
+
|
| 17 |
+
def mean(self) -> float:
|
| 18 |
+
return sum(self.values) / len(self.values) if self.values else 0.0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LatencyTracker:
|
| 22 |
+
def __init__(self, name: str = "latency_ms", sink: Optional[RollingAverage] = None):
|
| 23 |
+
self.name = name
|
| 24 |
+
self.sink = sink or RollingAverage()
|
| 25 |
+
self._start = 0.0
|
| 26 |
+
|
| 27 |
+
def __enter__(self):
|
| 28 |
+
self._start = time.perf_counter()
|
| 29 |
+
return self
|
| 30 |
+
|
| 31 |
+
def __exit__(self, exc_type, exc, tb):
|
| 32 |
+
elapsed_ms = (time.perf_counter() - self._start) * 1000
|
| 33 |
+
self.sink.add(elapsed_ms)
|
| 34 |
+
|
| 35 |
+
@property
|
| 36 |
+
def avg_ms(self) -> float:
|
| 37 |
+
return self.sink.mean()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _safe_div(num: float, den: float) -> float:
|
| 41 |
+
return num / den if den else 0.0
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Retrieval metrics for RAG
|
| 45 |
+
def hit_rate_at_k(retrieved: Sequence[Sequence[str]], gold: Sequence[Sequence[str]], k: int = 5) -> float:
|
| 46 |
+
hits = 0
|
| 47 |
+
total = len(retrieved)
|
| 48 |
+
for preds, truths in zip(retrieved, gold):
|
| 49 |
+
topk = set(preds[:k])
|
| 50 |
+
truths_set = set(truths)
|
| 51 |
+
hits += 1 if topk & truths_set else 0
|
| 52 |
+
return _safe_div(hits, total)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def recall_at_k(retrieved: Sequence[Sequence[str]], gold: Sequence[Sequence[str]], k: int = 5) -> float:
|
| 56 |
+
total_recall = 0.0
|
| 57 |
+
total = len(retrieved)
|
| 58 |
+
for preds, truths in zip(retrieved, gold):
|
| 59 |
+
topk = set(preds[:k])
|
| 60 |
+
truths_set = set(truths)
|
| 61 |
+
if truths_set:
|
| 62 |
+
total_recall += len(topk & truths_set) / len(truths_set)
|
| 63 |
+
return _safe_div(total_recall, total)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def mrr_at_k(retrieved: Sequence[Sequence[str]], gold: Sequence[Sequence[str]], k: int = 5) -> float:
|
| 67 |
+
mrr = 0.0
|
| 68 |
+
total = len(retrieved)
|
| 69 |
+
for preds, truths in zip(retrieved, gold):
|
| 70 |
+
truths_set = set(truths)
|
| 71 |
+
rr = 0.0
|
| 72 |
+
for rank, pid in enumerate(preds[:k], start=1):
|
| 73 |
+
if pid in truths_set:
|
| 74 |
+
rr = 1.0 / rank
|
| 75 |
+
break
|
| 76 |
+
mrr += rr
|
| 77 |
+
return _safe_div(mrr, total)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def evaluate_retrieval(
|
| 81 |
+
retrieved: Sequence[Sequence[str]],
|
| 82 |
+
gold: Sequence[Sequence[str]],
|
| 83 |
+
k: int = 5,
|
| 84 |
+
) -> Dict[str, float]:
|
| 85 |
+
return {
|
| 86 |
+
"hit_rate@k": hit_rate_at_k(retrieved, gold, k),
|
| 87 |
+
"recall@k": recall_at_k(retrieved, gold, k),
|
| 88 |
+
"mrr@k": mrr_at_k(retrieved, gold, k),
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Token utilities
|
| 93 |
+
def estimate_tokens(text: str, model_name: Optional[str] = None) -> int:
|
| 94 |
+
try:
|
| 95 |
+
import tiktoken # type: ignore
|
| 96 |
+
enc = tiktoken.get_encoding("cl100k_base") if not model_name else tiktoken.encoding_for_model(model_name)
|
| 97 |
+
return len(enc.encode(text))
|
| 98 |
+
except Exception:
|
| 99 |
+
# Fallback: rough heuristic ~ 4 chars per token
|
| 100 |
+
return max(1, math.ceil(len(text) / 4))
|
| 101 |
+
|
| 102 |
+
|