Upload 10 files
Browse files
- .gitattributes +35 -35
- .github/workflows/sync_HFSpace.yml +18 -0
- .gitignore +162 -0
- README.md +58 -13
- app.py +343 -0
- indexing.py +83 -0
- prompt_template.json +5 -0
- requirements-dev.txt +2 -0
- requirements.txt +14 -0
- retrieval.py +114 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
All 35 rules were rewritten with identical content (old and new versions match):

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/sync_HFSpace.yml
ADDED
@@ -0,0 +1,18 @@
name: Sync to Hugging Face hub
on:
  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          lfs: true
      - name: Push to hub
        env:
          # HF_TOKEN must be defined as a repository secret (a Hugging Face token with write access)
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push https://cvachet:$HF_TOKEN@huggingface.co/spaces/cvachet/pdf-chatbot main
.gitignore
ADDED
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
README.md
CHANGED
@@ -1,13 +1,58 @@
Removed (original front matter; several values are truncated in this diff view):
---
title: PDF
emoji:
colorFrom:
colorTo:
sdk: gradio
sdk_version: 5.
app_file: app.py
pinned:
---

Added:
---
title: PDF Chatbot
emoji: 🌍
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 5.16.1
app_file: app.py
pinned: true
---

[Python](https://www.python.org/downloads/)
[Code style: black](https://github.com/psf/black)
[Linting: pylint](https://github.com/pylint-dev/pylint)

**Aim: PDF-based AI chatbot with retrieval-augmented generation**

**Architecture / Tech stack:**
- Front-end:
  - User interface via the Gradio library
- Back-end:
  - HuggingFace embeddings
  - HuggingFace Inference API for open-source LLMs
  - ChromaDB vector database
  - LangChain conversational retrieval chain

You can try out the deployed [Hugging Face Space](https://huggingface.co/spaces/cvachet/pdf-chatbot)!

----

### Overview

**Description:**
This AI assistant, built with LangChain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. The user interface explicitly shows the multiple steps of the RAG workflow to make them easier to understand. The chatbot takes past questions into account when generating answers (via conversational memory) and includes document references for transparency. It relies on small LLMs that can run directly on CPU hardware.

**Available open-source LLMs:**
- Meta Llama series
- Alibaba Qwen2.5 series
- Mistral AI models
- Microsoft Phi-3.5 series
- Google Gemma models
- HuggingFace zephyr and SmolLM series

### Local execution

Command line for execution (after installing the dependencies with `pip install -r requirements.txt`):
> python3 app.py

The Gradio web application should now be accessible at http://localhost:7860
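For a sense of how the modules in this commit fit together outside the Gradio UI, here is a minimal end-to-end sketch of the four steps. It is illustrative only: `sample.pdf` is a hypothetical input file, and it assumes `HUGGINGFACE_API_KEY` is set in `.env`, as `app.py` expects.

```python
# Minimal RAG sketch using indexing.py and retrieval.py (not part of the app itself).
# Assumptions: a local "sample.pdf" (hypothetical) and HUGGINGFACE_API_KEY in .env.
import os
from dotenv import load_dotenv

import indexing
import retrieval

load_dotenv()
token = os.environ.get("HUGGINGFACE_API_KEY")

# Steps 1-2: load and split the PDF, then build the Chroma vector database
splits = indexing.load_doc(["sample.pdf"], chunk_size=600, chunk_overlap=40)
collection = indexing.create_collection_name("sample.pdf")
vector_db = indexing.create_db(splits, collection)

# Step 3: initialize the conversational retrieval chain on a hosted LLM
# (progress is stubbed out since there is no Gradio event to report to)
qa_chain = retrieval.initialize_llmchain(
    "mistralai/Mistral-7B-Instruct-v0.3",
    token,
    0.7,   # temperature
    1024,  # max new tokens
    3,     # top-k
    vector_db,
    progress=lambda *args, **kwargs: None,
)

# Step 4: ask a question; history is a list of (user, assistant) tuples
qa_chain, history, sources = retrieval.invoke_qa_chain(
    qa_chain, "Can you summarize this document in one paragraph?", []
)
print(history[-1][1])       # the assistant's answer
print(sources[0].metadata)  # provenance of the first retrieved chunk
```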
app.py
ADDED
@@ -0,0 +1,343 @@
"""
PDF-based chatbot with Retrieval-Augmented Generation
"""

import os
import gradio as gr

from dotenv import load_dotenv

import indexing
import retrieval


# default_persist_directory = './chroma_HF/'
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-2-2b-it",
    "google/gemma-2-9b-it",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


# Load environment file - HuggingFace API key
def retrieve_api():
    """Retrieve HuggingFace API Key"""
    _ = load_dotenv()
    global huggingfacehub_api_token
    huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")


# Initialize database
def initialize_database(
    list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()
):
    """Initialize database"""

    # Create list of documents (when valid)
    list_file_path = [x.name for x in list_file_obj if x is not None]

    # Create collection_name for vector database
    progress(0.1, desc="Creating collection name...")
    collection_name = indexing.create_collection_name(list_file_path[0])

    progress(0.25, desc="Loading document...")
    # Load document and create splits
    doc_splits = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)

    # Create or load vector database
    progress(0.5, desc="Generating vector database...")
    vector_db = indexing.create_db(doc_splits, collection_name)

    return vector_db, collection_name, "Complete!"


# Initialize LLM
def initialize_llm(
    llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()
):
    """Initialize LLM"""

    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    qa_chain = retrieval.initialize_llmchain(
        llm_name, huggingfacehub_api_token, llm_temperature, max_tokens, top_k, vector_db, progress
    )
    return qa_chain, "Complete!"


# Chatbot conversation
def conversation(qa_chain, message, history):
    """Chatbot conversation"""

    qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(
        qa_chain, message, history
    )

    # Format output gradio components
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # Langchain sources are zero-based
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1

    return (
        qa_chain,
        gr.update(value=""),
        new_history,
        response_source1,
        response_source1_page,
        response_source2,
        response_source2_page,
        response_source3,
        response_source3_page,
    )


SPACE_TITLE = """
<center><h2>PDF-based chatbot</h2></center>
<h3>Ask any questions about your PDF documents</h3>
"""

SPACE_INFO = """
<b>Description:</b> This AI assistant, using LangChain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
The user interface explicitly shows multiple steps to help understand the RAG workflow.
This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for transparency.<br>
<br><b>Notes:</b> Updated space with more recent LLM models (Qwen 2.5, Llama 3.2, SmolLM2 series)
<br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
"""


# Gradio User Interface
def gradio_ui():
    """Gradio User Interface"""

    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(SPACE_TITLE)
        gr.Markdown(SPACE_INFO)

        with gr.Tab("Step 1 - Upload PDF"):
            with gr.Row():
                document = gr.File(
                    height=200,
                    file_count="multiple",
                    file_types=[".pdf"],
                    interactive=True,
                    label="Upload your PDF documents (single or multiple)",
                )

        with gr.Tab("Step 2 - Process document"):
            with gr.Row():
                # Database selector (currently ChromaDB only)
                db_type_radio = gr.Radio(
                    ["ChromaDB"],
                    label="Vector database type",
                    value="ChromaDB",
                    type="index",
                    info="Choose your vector database",
                )
            with gr.Accordion("Advanced options - Document text splitter", open=False):
                with gr.Row():
                    slider_chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=600,
                        step=20,
                        label="Chunk size",
                        info="Chunk size",
                        interactive=True,
                    )
                with gr.Row():
                    slider_chunk_overlap = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=40,
                        step=10,
                        label="Chunk overlap",
                        info="Chunk overlap",
                        interactive=True,
                    )
            with gr.Row():
                db_progress = gr.Textbox(
                    label="Vector database initialization", value="None"
                )
            with gr.Row():
                db_btn = gr.Button("Generate vector database")

        with gr.Tab("Step 3 - Initialize QA chain"):
            with gr.Row():
                llm_btn = gr.Radio(
                    list_llm_simple,
                    label="LLM models",
                    value=list_llm_simple[0],
                    type="index",
                    info="Choose your LLM model",
                )
            with gr.Accordion("Advanced options - LLM model", open=False):
                with gr.Row():
                    slider_temperature = gr.Slider(
                        minimum=0.01,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Model temperature",
                        interactive=True,
                    )
                with gr.Row():
                    slider_maxtokens = gr.Slider(
                        minimum=224,
                        maximum=4096,
                        value=1024,
                        step=32,
                        label="Max Tokens",
                        info="Model max tokens",
                        interactive=True,
                    )
                with gr.Row():
                    slider_topk = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="top-k samples",
                        info="Model top-k samples",
                        interactive=True,
                    )
            with gr.Row():
                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            with gr.Row():
                qachain_btn = gr.Button("Initialize Question Answering chain")

        with gr.Tab("Step 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    doc_source1 = gr.Textbox(
                        label="Reference 1", lines=2, container=True, scale=20
                    )
                    source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source2 = gr.Textbox(
                        label="Reference 2", lines=2, container=True, scale=20
                    )
                    source2_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source3 = gr.Textbox(
                        label="Reference 3", lines=2, container=True, scale=20
                    )
                    source3_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type message (e.g. 'Can you summarize this document in one paragraph?')",
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton(
                    components=[msg, chatbot], value="Clear conversation"
                )

        # Preprocessing events
        db_btn.click(
            initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress],
        )
        qachain_btn.click(
            initialize_llm,
            inputs=[
                llm_btn,
                slider_temperature,
                slider_maxtokens,
                slider_topk,
                vector_db,
            ],
            outputs=[qa_chain, llm_progress],
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[
                chatbot,
                doc_source1,
                source1_page,
                doc_source2,
                source2_page,
                doc_source3,
                source3_page,
            ],
            queue=False,
        )

        # Chatbot events
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[
                qa_chain,
                msg,
                chatbot,
                doc_source1,
                source1_page,
                doc_source2,
                source2_page,
                doc_source3,
                source3_page,
            ],
            queue=False,
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[
                qa_chain,
                msg,
                chatbot,
                doc_source1,
                source1_page,
                doc_source2,
                source2_page,
                doc_source3,
                source3_page,
            ],
            queue=False,
        )
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[
                chatbot,
                doc_source1,
                source1_page,
                doc_source2,
                source2_page,
                doc_source3,
                source3_page,
            ],
            queue=False,
        )
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    retrieve_api()
    gradio_ui()
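One note on the final `demo.queue().launch(debug=True)` call: when the app runs inside a container or on a remote host rather than locally, Gradio's standard launch options can be passed. A hypothetical variant, not part of this commit:

```python
# Hypothetical variant: bind to all interfaces on the default Gradio port
demo.queue().launch(server_name="0.0.0.0", server_port=7860, debug=True)
```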
indexing.py
ADDED
@@ -0,0 +1,83 @@
"""
Indexing with vector database
"""

from pathlib import Path
import re

import chromadb

from unidecode import unidecode

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings


# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF document and create doc splits"""

    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits


# Generate collection name for vector database
# - Use filepath as input, ensuring unicode text
# - Handle multiple languages (arabic, chinese)
def create_collection_name(filepath):
    """Create collection name for vector database"""

    # Extract filename without extension
    collection_name = Path(filepath).stem
    # Fix potential issues from naming convention
    ## Remove space
    collection_name = collection_name.replace(" ", "-")
    ## ASCII transliterations of Unicode text
    collection_name = unidecode(collection_name)
    ## Remove special characters
    collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name)
    ## Limit length to 50 characters
    collection_name = collection_name[:50]
    ## Minimum length of 3 characters
    if len(collection_name) < 3:
        collection_name = collection_name + "xyz"
    ## Enforce start and end as alphanumeric character
    if not collection_name[0].isalnum():
        collection_name = "A" + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + "Z"
    print("\n\nFilepath: ", filepath)
    print("Collection name: ", collection_name)
    return collection_name


# Create vector database
def create_db(splits, collection_name):
    """Create embeddings and vector database"""

    embedding = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        # model_name="sentence-transformers/all-MiniLM-L6-v2",
        # model_kwargs={"device": "cpu"},
        # encode_kwargs={'normalize_embeddings': False}
    )
    chromadb.api.client.SharedSystemClient.clear_system_cache()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
        # persist_directory=default_persist_directory
    )
    return vectordb
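To illustrate the sanitization in `create_collection_name` above, here is a hypothetical file path traced through each step (spaces to dashes, `unidecode` transliteration, special-character collapsing, and the alphanumeric end-of-name fix):

```python
import indexing

# "Résumé 2024!" -> "Résumé-2024!" -> "Resume-2024!" -> "Resume-2024-" -> "Resume-2024Z"
print(indexing.create_collection_name("/tmp/Résumé 2024!.pdf"))  # Resume-2024Z
```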
prompt_template.json
ADDED
@@ -0,0 +1,5 @@
{
    "title": "System prompt",
    "prompt": "You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise. Question: {question} \\n Context: {context} \\n Helpful Answer:"
}
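For reference, a short sketch of how this file is consumed, mirroring the loading code in `retrieval.py` further below:

```python
import json

from langchain_core.prompts import PromptTemplate

# Load the system prompt and fill in its two placeholders
with open("prompt_template.json", "r") as file:
    template = json.load(file)["prompt"]

rag_prompt = PromptTemplate(template=template, input_variables=["context", "question"])
print(rag_prompt.format(context="(retrieved chunks)", question="(user question)"))
```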
requirements-dev.txt
ADDED
@@ -0,0 +1,2 @@
pylint
black
requirements.txt
ADDED
@@ -0,0 +1,14 @@
transformers[torch]
sentence-transformers
langchain
langchain-community
langchain-huggingface
langchain-chroma
huggingface-hub
tqdm
accelerate
pypdf
chromadb
unidecode
gradio
python-dotenv
retrieval.py
ADDED
@@ -0,0 +1,114 @@
"""
LLM chain retrieval
"""

import json
import gradio as gr

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate


# Add system template for RAG application
# (kept for reference; initialize_llmchain below loads the same prompt from prompt_template.json)
PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
Question: {question}
Context: {context}
Helpful Answer:
"""


# Initialize langchain LLM chain
def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
    """Initialize Langchain LLM chain"""

    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub uses HF inference endpoints
    progress(0.5, desc="Initializing HF Hub...")
    # Use of trust_remote_code as model_kwargs
    # Warning: langchain issue
    # URL: https://github.com/langchain-ai/langchain/issues/6080

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text-generation",
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
    retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    with open("prompt_template.json", "r") as file:
        system_prompt = json.load(file)
    prompt_template = system_prompt["prompt"]
    rag_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        # return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")

    return qa_chain


def format_chat_history(message, chat_history):
    """Format chat history for llm chain"""

    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history


def invoke_qa_chain(qa_chain, message, history):
    """Invoke question-answering chain"""

    formatted_chat_history = format_chat_history(message, history)

    # Generate response using QA chain
    response = qa_chain.invoke(
        {"question": message, "chat_history": formatted_chat_history}
    )

    response_sources = response["source_documents"]

    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]

    return qa_chain, new_history, response_sources
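A small illustration of `format_chat_history` with hypothetical messages; note that the current `message` argument is not itself included in the returned list:

```python
import retrieval

history = [("What is RAG?", "Retrieval-augmented generation grounds answers in retrieved text.")]
print(retrieval.format_chat_history("What about vector databases?", history))
# ['User: What is RAG?',
#  'Assistant: Retrieval-augmented generation grounds answers in retrieved text.']
```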