Spaces:
Sleeping
Sleeping
Commit ยท
9630ae8
0
Parent(s):
Initial commit - RFPilot experiment
Browse files- .dockerignore +7 -0
- .gitignore +234 -0
- Dockerfile +30 -0
- LICENSE +21 -0
- README.md +30 -0
- app.py +284 -0
- app.py.old +284 -0
- data/eval_dataset.json +59 -0
- data/eval_template.csv +3 -0
- requirements.txt +29 -0
- src/analyze_results.py +0 -0
- src/compare_models.py +325 -0
- src/create_eval_dataset.py +247 -0
- src/eval_dataset.py +33 -0
- src/generator/generator_gguf.py +598 -0
- src/generator/generator_gguf_base.py +516 -0
- src/generator/generator_gguf_no_rag.py +396 -0
- src/retriever/main.py +67 -0
- src/retriever/retriever.py +313 -0
- src/utils/config.py +193 -0
.dockerignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.git
|
| 4 |
+
.env
|
| 5 |
+
.venv
|
| 6 |
+
*.log
|
| 7 |
+
EOF
|
.gitignore
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# ๋ชจ๋ธ & DB
|
| 210 |
+
chroma_db/
|
| 211 |
+
models/
|
| 212 |
+
*.gguf
|
| 213 |
+
.cache/
|
| 214 |
+
|
| 215 |
+
# Python
|
| 216 |
+
__pycache__/
|
| 217 |
+
*.py[cod]
|
| 218 |
+
*$py.class
|
| 219 |
+
*.so
|
| 220 |
+
.Python
|
| 221 |
+
env/
|
| 222 |
+
venv/
|
| 223 |
+
|
| 224 |
+
# IDE
|
| 225 |
+
.vscode/
|
| 226 |
+
.idea/
|
| 227 |
+
|
| 228 |
+
# OS
|
| 229 |
+
.DS_Store
|
| 230 |
+
Thumbs.db
|
| 231 |
+
|
| 232 |
+
# Env
|
| 233 |
+
.env
|
| 234 |
+
.env.local
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Python ์ค์น
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
python3.10 \
|
| 8 |
+
python3-pip \
|
| 9 |
+
git \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade pip
|
| 13 |
+
|
| 14 |
+
# requirements.txt ๋ณต์ฌ
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
|
| 17 |
+
# llama-cpp-python (wheel)
|
| 18 |
+
RUN pip install --no-cache-dir \
|
| 19 |
+
llama-cpp-python==0.2.90 \
|
| 20 |
+
--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
|
| 21 |
+
|
| 22 |
+
# ๋๋จธ์ง ํจํค์ง
|
| 23 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 24 |
+
|
| 25 |
+
# ์์ค ๋ณต์ฌ
|
| 26 |
+
COPY . .
|
| 27 |
+
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
CMD ["python3", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Dongjin
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: RFPilot Model Comparison
|
| 3 |
+
emoji: ๐ฌ
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# QLoRA_RAG_test
|
| 14 |
+
---
|
| 15 |
+
์ด ํ๋ก์ ํธ๋ RFP ๋ฌธ์์์ฝ RAG ์ฑ๋ด ํ๋ก์ ํธ์ ํ์ ์ฐ๊ตฌ ์
๋๋ค.
|
| 16 |
+
|
| 17 |
+
## ๋ฌธ์
|
| 18 |
+
๊ธฐ์กด ์๋น์ค ๊ตฌ์กฐ์์ Fine-Tuning๋ ๋ชจ๋ธ์ RAG ์์คํ
์ ์ ์ฉํ์๋ ๊ฒ์ด ๊ณผ์ ํฉ์ ์ผ๊ธฐํ๋์ง, ์ด๋ค ํจ๊ณผ๊ฐ ์๋์ง ํ์ธ ํ์ง ๋ชปํ์๋ค.
|
| 19 |
+
|
| 20 |
+
## ์คํ ์ ์ฐจ
|
| 21 |
+
|
| 22 |
+
- QLoRA ๋ ๋ชจ๋ธ์ ์ค๋นํ๋ค.
|
| 23 |
+
- Fine-Tuning ํ์ง ์์ ์๋ณธ ๋ชจ๋ธ์ ์ค๋นํ๋ค.
|
| 24 |
+
- ํ๊ฐ ๋ฐ์ดํฐ์
์ ์์ฑํ๋ค.
|
| 25 |
+
- Fine-Tuning์ ํ ๊ฒฝ์ฐ, Fine-Tuning์ ํ์ง ์๊ณ RAG๋ง ์ ์ฉํ ๊ฒฝ์ฐ, ๋ ๋ค ์ ์ฉํ ๊ฒฝ์ฐ๋ฅผ ๋๋ ํ
์คํธ๋ฅผ ํด๋ณธ๋ค.
|
| 26 |
+
- ๊ฒฐ๊ณผ๋ฅผ ํ์ธ ํ๋ค.
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## ๊ฒฐ๊ณผ
|
app.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HuggingFace Space์ฉ ์คํ ์ฑ
|
| 3 |
+
|
| 4 |
+
Gradio๋ฅผ ์ฌ์ฉํ์ฌ ์น UI์์ ์คํ์ ์คํํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํฉ๋๋ค.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
# ํ๋ก์ ํธ ๊ฒฝ๋ก ์ค์
|
| 16 |
+
import sys
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 18 |
+
|
| 19 |
+
from src.eval_dataset import EvalDataset
|
| 20 |
+
from src.compare_models import ModelComparison
|
| 21 |
+
from src.analyze_results import ResultAnalyzer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ExperimentApp:
|
| 25 |
+
"""์คํ ์ฑ ํด๋์ค"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
|
| 28 |
+
self.experiment = None
|
| 29 |
+
self.latest_result_file = None
|
| 30 |
+
|
| 31 |
+
def setup_environment(self, api_key: str) -> str:
|
| 32 |
+
"""ํ๊ฒฝ ์ค์ """
|
| 33 |
+
if not api_key:
|
| 34 |
+
return "โ OpenAI API ํค๋ฅผ ์
๋ ฅํด์ฃผ์ธ์."
|
| 35 |
+
|
| 36 |
+
os.environ['OPENAI_API_KEY'] = api_key
|
| 37 |
+
os.environ['USE_MODEL_HUB'] = 'true'
|
| 38 |
+
os.environ['GGUF_N_GPU_LAYERS'] = '35'
|
| 39 |
+
|
| 40 |
+
return "โ
ํ๊ฒฝ ์ค์ ์๋ฃ!"
|
| 41 |
+
|
| 42 |
+
def run_experiment(
|
| 43 |
+
self,
|
| 44 |
+
api_key: str,
|
| 45 |
+
distribution: str,
|
| 46 |
+
progress=gr.Progress()
|
| 47 |
+
) -> tuple:
|
| 48 |
+
"""์คํ ์คํ"""
|
| 49 |
+
try:
|
| 50 |
+
# ํ๊ฒฝ ์ค์
|
| 51 |
+
setup_msg = self.setup_environment(api_key)
|
| 52 |
+
if "โ" in setup_msg:
|
| 53 |
+
return setup_msg, None, None
|
| 54 |
+
|
| 55 |
+
progress(0.1, desc="ํ๊ฒฝ ์ค์ ์๋ฃ...")
|
| 56 |
+
|
| 57 |
+
# Config ๋ก๋
|
| 58 |
+
from src.utils.config import RAGConfig
|
| 59 |
+
config = RAGConfig()
|
| 60 |
+
|
| 61 |
+
progress(0.2, desc="์คํ ์ด๊ธฐํ ์ค...")
|
| 62 |
+
|
| 63 |
+
# ์คํ ์ด๊ธฐํ
|
| 64 |
+
self.experiment = ModelComparison(
|
| 65 |
+
config=config,
|
| 66 |
+
output_dir="./experiments/results"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
progress(0.3, desc="๋ชจ๋ธ ๋ก๋ฉ ์ค... (3-5๋ถ ์์)")
|
| 70 |
+
|
| 71 |
+
# ๋ชจ๋ธ ๋ก๋
|
| 72 |
+
self.experiment.load_models()
|
| 73 |
+
|
| 74 |
+
progress(0.5, desc="์คํ ์คํ ์ค... (10-20๋ถ ์์)")
|
| 75 |
+
|
| 76 |
+
# ์คํ ์คํ
|
| 77 |
+
results = self.experiment.run_experiment(
|
| 78 |
+
distribution=distribution.lower(),
|
| 79 |
+
save_results=True
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
progress(0.9, desc="๊ฒฐ๊ณผ ๋ถ์ ์ค...")
|
| 83 |
+
|
| 84 |
+
# ์ต์ ๊ฒฐ๊ณผ ํ์ผ ์ฐพ๊ธฐ
|
| 85 |
+
result_files = sorted(
|
| 86 |
+
Path("./experiments/results").glob("results_*.json"),
|
| 87 |
+
reverse=True
|
| 88 |
+
)
|
| 89 |
+
self.latest_result_file = str(result_files[0]) if result_files else None
|
| 90 |
+
|
| 91 |
+
# ์์ฝ ์์ฑ
|
| 92 |
+
summary = self._generate_summary(results)
|
| 93 |
+
|
| 94 |
+
# CSV ์์ฑ
|
| 95 |
+
df = self._results_to_dataframe(results)
|
| 96 |
+
|
| 97 |
+
progress(1.0, desc="์๋ฃ!")
|
| 98 |
+
|
| 99 |
+
return "โ
์คํ ์๋ฃ!", summary, df
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
return f"โ ์คํ ์คํจ: {str(e)}", None, None
|
| 103 |
+
|
| 104 |
+
def _generate_summary(self, results: dict) -> str:
|
| 105 |
+
"""์์ฝ ์์ฑ"""
|
| 106 |
+
summary = "=" * 60 + "\n"
|
| 107 |
+
summary += "์คํ ๊ฒฐ๊ณผ ์์ฝ\n"
|
| 108 |
+
summary += "=" * 60 + "\n\n"
|
| 109 |
+
|
| 110 |
+
metadata = results['metadata']
|
| 111 |
+
summary += f"ํ์์คํฌํ: {metadata['timestamp']}\n"
|
| 112 |
+
summary += f"๋ถํฌ: {metadata['distribution']}\n"
|
| 113 |
+
summary += f"๋ชจ๋ธ: {', '.join(metadata['models'])}\n"
|
| 114 |
+
summary += f"์ด ์ง๋ฌธ ์: {metadata['total_queries']}\n\n"
|
| 115 |
+
|
| 116 |
+
# ๊ฐ ๋ถํฌ๋ณ ์์ฝ
|
| 117 |
+
for dist_type, dist_results in results['results'].items():
|
| 118 |
+
summary += f"\n{'='*60}\n"
|
| 119 |
+
summary += f"{dist_type.upper()}\n"
|
| 120 |
+
summary += f"{'='*60}\n\n"
|
| 121 |
+
|
| 122 |
+
# ๋ชจ๋ธ๋ณ ํต๊ณ
|
| 123 |
+
model_stats = {}
|
| 124 |
+
for result in dist_results:
|
| 125 |
+
model = result['model']
|
| 126 |
+
if model not in model_stats:
|
| 127 |
+
model_stats[model] = []
|
| 128 |
+
model_stats[model].append(result)
|
| 129 |
+
|
| 130 |
+
for model, model_results in model_stats.items():
|
| 131 |
+
success_count = sum(1 for r in model_results if r['success'])
|
| 132 |
+
avg_time = sum(r['elapsed_time'] for r in model_results if r['success']) / max(success_count, 1)
|
| 133 |
+
|
| 134 |
+
summary += f"[{model}]\n"
|
| 135 |
+
summary += f" ์ฑ๊ณต: {success_count}/{len(model_results)}\n"
|
| 136 |
+
summary += f" ํ๊ท ์๊ฐ: {avg_time:.3f}์ด\n\n"
|
| 137 |
+
|
| 138 |
+
return summary
|
| 139 |
+
|
| 140 |
+
def _results_to_dataframe(self, results: dict) -> pd.DataFrame:
|
| 141 |
+
"""๊ฒฐ๊ณผ๋ฅผ DataFrame์ผ๋ก ๋ณํ"""
|
| 142 |
+
all_rows = []
|
| 143 |
+
|
| 144 |
+
for dist_type, dist_results in results['results'].items():
|
| 145 |
+
for result in dist_results:
|
| 146 |
+
row = {
|
| 147 |
+
'distribution': dist_type,
|
| 148 |
+
'model': result['model'],
|
| 149 |
+
'query': result['query'],
|
| 150 |
+
'success': result['success'],
|
| 151 |
+
'elapsed_time': result['elapsed_time'],
|
| 152 |
+
'total_tokens': result.get('usage', {}).get('total_tokens', 0)
|
| 153 |
+
}
|
| 154 |
+
all_rows.append(row)
|
| 155 |
+
|
| 156 |
+
return pd.DataFrame(all_rows)
|
| 157 |
+
|
| 158 |
+
def analyze_results(self) -> tuple:
|
| 159 |
+
"""๊ฒฐ๊ณผ ๋ถ์"""
|
| 160 |
+
if not self.latest_result_file:
|
| 161 |
+
return "โ ๋จผ์ ์คํ์ ์คํํด์ฃผ์ธ์.", None, None, None, None
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
analyzer = ResultAnalyzer(self.latest_result_file)
|
| 165 |
+
|
| 166 |
+
# ๊ทธ๋ํ ์์ฑ
|
| 167 |
+
analyzer.plot_time_comparison()
|
| 168 |
+
analyzer.plot_token_comparison()
|
| 169 |
+
analyzer.plot_rag_usage()
|
| 170 |
+
analyzer.plot_overfitting_analysis()
|
| 171 |
+
|
| 172 |
+
# ๊ทธ๋ํ ํ์ผ ๊ฒฝ๋ก
|
| 173 |
+
analysis_dir = Path(self.latest_result_file).parent / "analysis"
|
| 174 |
+
|
| 175 |
+
time_plot = str(analysis_dir / "time_comparison.png")
|
| 176 |
+
token_plot = str(analysis_dir / "token_comparison.png")
|
| 177 |
+
rag_plot = str(analysis_dir / "rag_usage.png")
|
| 178 |
+
overfitting_plot = str(analysis_dir / "overfitting_analysis.png")
|
| 179 |
+
|
| 180 |
+
return (
|
| 181 |
+
"โ
๋ถ์ ์๋ฃ!",
|
| 182 |
+
time_plot if Path(time_plot).exists() else None,
|
| 183 |
+
token_plot if Path(token_plot).exists() else None,
|
| 184 |
+
rag_plot if Path(rag_plot).exists() else None,
|
| 185 |
+
overfitting_plot if Path(overfitting_plot).exists() else None
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
return f"โ ๋ถ์ ์คํจ: {str(e)}", None, None, None, None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Gradio ์ธํฐํ์ด์ค ์์ฑ
|
| 193 |
+
def create_interface():
|
| 194 |
+
"""Gradio ์ธํฐํ์ด์ค ์์ฑ"""
|
| 195 |
+
app = ExperimentApp()
|
| 196 |
+
|
| 197 |
+
with gr.Blocks(title="RFPilot ๋ชจ๋ธ ๋น๊ต ์คํ") as demo:
|
| 198 |
+
gr.Markdown("""
|
| 199 |
+
# ๐ฌ RFPilot ๋ชจ๋ธ ๋น๊ต ์คํ
|
| 200 |
+
|
| 201 |
+
3๊ฐ์ง ๋ชจ๋ธ(QLoRA+RAG, QLoRA ๋จ๋
, Base+RAG)์ ์ฑ๋ฅ์ ๋น๊ตํฉ๋๋ค.
|
| 202 |
+
|
| 203 |
+
โ ๏ธ **์ฃผ์**: ์คํ ์คํ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฝ๋๋ค (10-20๋ถ).
|
| 204 |
+
""")
|
| 205 |
+
|
| 206 |
+
with gr.Tab("๐ ์คํ ์คํ"):
|
| 207 |
+
api_key_input = gr.Textbox(
|
| 208 |
+
label="OpenAI API Key",
|
| 209 |
+
type="password",
|
| 210 |
+
placeholder="sk-..."
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
distribution_input = gr.Radio(
|
| 214 |
+
choices=["All", "In", "Out"],
|
| 215 |
+
value="All",
|
| 216 |
+
label="๋ถํฌ ์ ํ"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
run_btn = gr.Button("์คํ ์์", variant="primary")
|
| 220 |
+
|
| 221 |
+
status_output = gr.Textbox(label="์ํ", lines=2)
|
| 222 |
+
summary_output = gr.Textbox(label="์์ฝ", lines=20)
|
| 223 |
+
results_output = gr.Dataframe(label="๊ฒฐ๊ณผ")
|
| 224 |
+
|
| 225 |
+
run_btn.click(
|
| 226 |
+
fn=app.run_experiment,
|
| 227 |
+
inputs=[api_key_input, distribution_input],
|
| 228 |
+
outputs=[status_output, summary_output, results_output]
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
with gr.Tab("๐ ๊ฒฐ๊ณผ ๋ถ์"):
|
| 232 |
+
analyze_btn = gr.Button("๋ถ์ ์์", variant="primary")
|
| 233 |
+
|
| 234 |
+
analyze_status = gr.Textbox(label="์ํ")
|
| 235 |
+
|
| 236 |
+
with gr.Row():
|
| 237 |
+
time_plot = gr.Image(label="์๋ต ์๊ฐ ๋น๊ต")
|
| 238 |
+
token_plot = gr.Image(label="ํ ํฐ ์ฌ์ฉ๋ ๋น๊ต")
|
| 239 |
+
|
| 240 |
+
with gr.Row():
|
| 241 |
+
rag_plot = gr.Image(label="RAG ์ฌ์ฉ ํจํด")
|
| 242 |
+
overfitting_plot = gr.Image(label="๊ณผ์ ํฉ ๋ถ์")
|
| 243 |
+
|
| 244 |
+
analyze_btn.click(
|
| 245 |
+
fn=app.analyze_results,
|
| 246 |
+
outputs=[
|
| 247 |
+
analyze_status,
|
| 248 |
+
time_plot,
|
| 249 |
+
token_plot,
|
| 250 |
+
rag_plot,
|
| 251 |
+
overfitting_plot
|
| 252 |
+
]
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
with gr.Tab("โน๏ธ ์ ๋ณด"):
|
| 256 |
+
gr.Markdown("""
|
| 257 |
+
## ๐ ๋น๊ต ๋ชจ๋ธ
|
| 258 |
+
|
| 259 |
+
| ๋ชจ๋ธ | ์ค๋ช
|
|
| 260 |
+
|------|------|
|
| 261 |
+
| QLoRA + RAG | ๊ธฐ์กด ์๋น์ค (QLoRA fine-tuning + RAG) |
|
| 262 |
+
| QLoRA ๋จ๋
| RAG ์ ๊ฑฐ (QLoRA๋ง) |
|
| 263 |
+
| Base + RAG | PEFT ์ ๊ฑฐ (Base ๋ชจ๋ธ + RAG) |
|
| 264 |
+
|
| 265 |
+
## ๐ ์ธก์ ์งํ
|
| 266 |
+
|
| 267 |
+
- **๊ณผ์ ํฉ**: In-Distribution vs Out-Distribution ์ฑ๋ฅ ์ฐจ์ด
|
| 268 |
+
- **๋ต๋ณ ์๋**: ํ๊ท ์๋ต ์๊ฐ
|
| 269 |
+
- **ํ ํฐ ์ฌ์ฉ๋**: ํ๊ท ํ ํฐ ์๋น
|
| 270 |
+
- **RAG ์ฌ์ฉ ํจํด**: RAG ํ์ฉ๋
|
| 271 |
+
|
| 272 |
+
## โฑ๏ธ ์์ ์์ ์๊ฐ
|
| 273 |
+
|
| 274 |
+
- ๋ชจ๋ธ ๋ก๋ฉ: 3-5๋ถ
|
| 275 |
+
- ์คํ ์คํ: 10-20๋ถ (25๊ฐ ์ง๋ฌธ)
|
| 276 |
+
- ๊ฒฐ๊ณผ ๋ถ์: 1-2๋ถ
|
| 277 |
+
""")
|
| 278 |
+
|
| 279 |
+
return demo
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
demo = create_interface()
|
| 284 |
+
demo.launch(share=True)
|
app.py.old
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# HuggingFace Space์ฉ ์คํ ์ฑ
|
| 3 |
+
|
| 4 |
+
# Gradio๋ฅผ ์ฌ์ฉํ์ฌ ์น UI์์ ์คํ์ ์คํํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํฉ๋๋ค.
|
| 5 |
+
# """
|
| 6 |
+
|
| 7 |
+
# import gradio as gr
|
| 8 |
+
# import os
|
| 9 |
+
# import json
|
| 10 |
+
# import pandas as pd
|
| 11 |
+
# from pathlib import Path
|
| 12 |
+
# import matplotlib.pyplot as plt
|
| 13 |
+
# from datetime import datetime
|
| 14 |
+
|
| 15 |
+
# # ํ๋ก์ ํธ ๊ฒฝ๋ก ์ค์
|
| 16 |
+
# import sys
|
| 17 |
+
# sys.path.insert(0, str(Path(__file__).parent))
|
| 18 |
+
|
| 19 |
+
# from experiments.eval_dataset import EvalDataset
|
| 20 |
+
# from experiments.compare_models import ModelComparison
|
| 21 |
+
# from experiments.analyze_results import ResultAnalyzer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# class ExperimentApp:
|
| 25 |
+
# """์คํ ์ฑ ํด๋์ค"""
|
| 26 |
+
|
| 27 |
+
# def __init__(self):
|
| 28 |
+
# self.experiment = None
|
| 29 |
+
# self.latest_result_file = None
|
| 30 |
+
|
| 31 |
+
# def setup_environment(self, api_key: str) -> str:
|
| 32 |
+
# """ํ๊ฒฝ ์ค์ """
|
| 33 |
+
# if not api_key:
|
| 34 |
+
# return "โ OpenAI API ํค๋ฅผ ์
๋ ฅํด์ฃผ์ธ์."
|
| 35 |
+
|
| 36 |
+
# os.environ['OPENAI_API_KEY'] = api_key
|
| 37 |
+
# os.environ['USE_MODEL_HUB'] = 'true'
|
| 38 |
+
# os.environ['GGUF_N_GPU_LAYERS'] = '35'
|
| 39 |
+
|
| 40 |
+
# return "โ
ํ๊ฒฝ ์ค์ ์๋ฃ!"
|
| 41 |
+
|
| 42 |
+
# def run_experiment(
|
| 43 |
+
# self,
|
| 44 |
+
# api_key: str,
|
| 45 |
+
# distribution: str,
|
| 46 |
+
# progress=gr.Progress()
|
| 47 |
+
# ) -> tuple:
|
| 48 |
+
# """์คํ ์คํ"""
|
| 49 |
+
# try:
|
| 50 |
+
# # ํ๊ฒฝ ์ค์
|
| 51 |
+
# setup_msg = self.setup_environment(api_key)
|
| 52 |
+
# if "โ" in setup_msg:
|
| 53 |
+
# return setup_msg, None, None
|
| 54 |
+
|
| 55 |
+
# progress(0.1, desc="ํ๊ฒฝ ์ค์ ์๋ฃ...")
|
| 56 |
+
|
| 57 |
+
# # Config ๋ก๋
|
| 58 |
+
# from src.utils.config import RAGConfig
|
| 59 |
+
# config = RAGConfig()
|
| 60 |
+
|
| 61 |
+
# progress(0.2, desc="์คํ ์ด๊ธฐํ ์ค...")
|
| 62 |
+
|
| 63 |
+
# # ์คํ ์ด๊ธฐํ
|
| 64 |
+
# self.experiment = ModelComparison(
|
| 65 |
+
# config=config,
|
| 66 |
+
# output_dir="./experiments/results"
|
| 67 |
+
# )
|
| 68 |
+
|
| 69 |
+
# progress(0.3, desc="๋ชจ๋ธ ๋ก๋ฉ ์ค... (3-5๋ถ ์์)")
|
| 70 |
+
|
| 71 |
+
# # ๋ชจ๋ธ ๋ก๋
|
| 72 |
+
# self.experiment.load_models()
|
| 73 |
+
|
| 74 |
+
# progress(0.5, desc="์คํ ์คํ ์ค... (10-20๋ถ ์์)")
|
| 75 |
+
|
| 76 |
+
# # ์คํ ์คํ
|
| 77 |
+
# results = self.experiment.run_experiment(
|
| 78 |
+
# distribution=distribution.lower(),
|
| 79 |
+
# save_results=True
|
| 80 |
+
# )
|
| 81 |
+
|
| 82 |
+
# progress(0.9, desc="๊ฒฐ๊ณผ ๋ถ์ ์ค...")
|
| 83 |
+
|
| 84 |
+
# # ์ต์ ๊ฒฐ๊ณผ ํ์ผ ์ฐพ๊ธฐ
|
| 85 |
+
# result_files = sorted(
|
| 86 |
+
# Path("./experiments/results").glob("results_*.json"),
|
| 87 |
+
# reverse=True
|
| 88 |
+
# )
|
| 89 |
+
# self.latest_result_file = str(result_files[0]) if result_files else None
|
| 90 |
+
|
| 91 |
+
# # ์์ฝ ์์ฑ
|
| 92 |
+
# summary = self._generate_summary(results)
|
| 93 |
+
|
| 94 |
+
# # CSV ์์ฑ
|
| 95 |
+
# df = self._results_to_dataframe(results)
|
| 96 |
+
|
| 97 |
+
# progress(1.0, desc="์๋ฃ!")
|
| 98 |
+
|
| 99 |
+
# return "โ
์คํ ์๋ฃ!", summary, df
|
| 100 |
+
|
| 101 |
+
# except Exception as e:
|
| 102 |
+
# return f"โ ์คํ ์คํจ: {str(e)}", None, None
|
| 103 |
+
|
| 104 |
+
# def _generate_summary(self, results: dict) -> str:
|
| 105 |
+
# """์์ฝ ์์ฑ"""
|
| 106 |
+
# summary = "=" * 60 + "\n"
|
| 107 |
+
# summary += "์คํ ๊ฒฐ๊ณผ ์์ฝ\n"
|
| 108 |
+
# summary += "=" * 60 + "\n\n"
|
| 109 |
+
|
| 110 |
+
# metadata = results['metadata']
|
| 111 |
+
# summary += f"ํ์์คํฌํ: {metadata['timestamp']}\n"
|
| 112 |
+
# summary += f"๋ถํฌ: {metadata['distribution']}\n"
|
| 113 |
+
# summary += f"๋ชจ๋ธ: {', '.join(metadata['models'])}\n"
|
| 114 |
+
# summary += f"์ด ์ง๋ฌธ ์: {metadata['total_queries']}\n\n"
|
| 115 |
+
|
| 116 |
+
# # ๊ฐ ๋ถํฌ๋ณ ์์ฝ
|
| 117 |
+
# for dist_type, dist_results in results['results'].items():
|
| 118 |
+
# summary += f"\n{'='*60}\n"
|
| 119 |
+
# summary += f"{dist_type.upper()}\n"
|
| 120 |
+
# summary += f"{'='*60}\n\n"
|
| 121 |
+
|
| 122 |
+
# # ๋ชจ๋ธ๋ณ ํต๊ณ
|
| 123 |
+
# model_stats = {}
|
| 124 |
+
# for result in dist_results:
|
| 125 |
+
# model = result['model']
|
| 126 |
+
# if model not in model_stats:
|
| 127 |
+
# model_stats[model] = []
|
| 128 |
+
# model_stats[model].append(result)
|
| 129 |
+
|
| 130 |
+
# for model, model_results in model_stats.items():
|
| 131 |
+
# success_count = sum(1 for r in model_results if r['success'])
|
| 132 |
+
# avg_time = sum(r['elapsed_time'] for r in model_results if r['success']) / max(success_count, 1)
|
| 133 |
+
|
| 134 |
+
# summary += f"[{model}]\n"
|
| 135 |
+
# summary += f" ์ฑ๊ณต: {success_count}/{len(model_results)}\n"
|
| 136 |
+
# summary += f" ํ๊ท ์๊ฐ: {avg_time:.3f}์ด\n\n"
|
| 137 |
+
|
| 138 |
+
# return summary
|
| 139 |
+
|
| 140 |
+
# def _results_to_dataframe(self, results: dict) -> pd.DataFrame:
|
| 141 |
+
# """๊ฒฐ๊ณผ๋ฅผ DataFrame์ผ๋ก ๋ณํ"""
|
| 142 |
+
# all_rows = []
|
| 143 |
+
|
| 144 |
+
# for dist_type, dist_results in results['results'].items():
|
| 145 |
+
# for result in dist_results:
|
| 146 |
+
# row = {
|
| 147 |
+
# 'distribution': dist_type,
|
| 148 |
+
# 'model': result['model'],
|
| 149 |
+
# 'query': result['query'],
|
| 150 |
+
# 'success': result['success'],
|
| 151 |
+
# 'elapsed_time': result['elapsed_time'],
|
| 152 |
+
# 'total_tokens': result.get('usage', {}).get('total_tokens', 0)
|
| 153 |
+
# }
|
| 154 |
+
# all_rows.append(row)
|
| 155 |
+
|
| 156 |
+
# return pd.DataFrame(all_rows)
|
| 157 |
+
|
| 158 |
+
# def analyze_results(self) -> tuple:
|
| 159 |
+
# """๊ฒฐ๊ณผ ๋ถ์"""
|
| 160 |
+
# if not self.latest_result_file:
|
| 161 |
+
# return "โ ๋จผ์ ์คํ์ ์คํํด์ฃผ์ธ์.", None, None, None, None
|
| 162 |
+
|
| 163 |
+
# try:
|
| 164 |
+
# analyzer = ResultAnalyzer(self.latest_result_file)
|
| 165 |
+
|
| 166 |
+
# # ๊ทธ๋ํ ์์ฑ
|
| 167 |
+
# analyzer.plot_time_comparison()
|
| 168 |
+
# analyzer.plot_token_comparison()
|
| 169 |
+
# analyzer.plot_rag_usage()
|
| 170 |
+
# analyzer.plot_overfitting_analysis()
|
| 171 |
+
|
| 172 |
+
# # ๊ทธ๋ํ ํ์ผ ๊ฒฝ๋ก
|
| 173 |
+
# analysis_dir = Path(self.latest_result_file).parent / "analysis"
|
| 174 |
+
|
| 175 |
+
# time_plot = str(analysis_dir / "time_comparison.png")
|
| 176 |
+
# token_plot = str(analysis_dir / "token_comparison.png")
|
| 177 |
+
# rag_plot = str(analysis_dir / "rag_usage.png")
|
| 178 |
+
# overfitting_plot = str(analysis_dir / "overfitting_analysis.png")
|
| 179 |
+
|
| 180 |
+
# return (
|
| 181 |
+
# "โ
๋ถ์ ์๋ฃ!",
|
| 182 |
+
# time_plot if Path(time_plot).exists() else None,
|
| 183 |
+
# token_plot if Path(token_plot).exists() else None,
|
| 184 |
+
# rag_plot if Path(rag_plot).exists() else None,
|
| 185 |
+
# overfitting_plot if Path(overfitting_plot).exists() else None
|
| 186 |
+
# )
|
| 187 |
+
|
| 188 |
+
# except Exception as e:
|
| 189 |
+
# return f"โ ๋ถ์ ์คํจ: {str(e)}", None, None, None, None
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# # Gradio ์ธํฐํ์ด์ค ์์ฑ
|
| 193 |
+
# def create_interface():
|
| 194 |
+
# """Gradio ์ธํฐํ์ด์ค ์์ฑ"""
|
| 195 |
+
# app = ExperimentApp()
|
| 196 |
+
|
| 197 |
+
# with gr.Blocks(title="RFPilot ๋ชจ๋ธ ๋น๊ต ์คํ") as demo:
|
| 198 |
+
# gr.Markdown("""
|
| 199 |
+
# # ๐ฌ RFPilot ๋ชจ๋ธ ๋น๊ต ์คํ
|
| 200 |
+
|
| 201 |
+
# 3๊ฐ์ง ๋ชจ๋ธ(QLoRA+RAG, QLoRA ๋จ๋
, Base+RAG)์ ์ฑ๋ฅ์ ๋น๊ตํฉ๋๋ค.
|
| 202 |
+
|
| 203 |
+
# โ ๏ธ **์ฃผ์**: ์คํ ์คํ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฝ๋๋ค (10-20๋ถ).
|
| 204 |
+
# """)
|
| 205 |
+
|
| 206 |
+
# with gr.Tab("๐ ์คํ ์คํ"):
|
| 207 |
+
# api_key_input = gr.Textbox(
|
| 208 |
+
# label="OpenAI API Key",
|
| 209 |
+
# type="password",
|
| 210 |
+
# placeholder="sk-..."
|
| 211 |
+
# )
|
| 212 |
+
|
| 213 |
+
# distribution_input = gr.Radio(
|
| 214 |
+
# choices=["All", "In", "Out"],
|
| 215 |
+
# value="All",
|
| 216 |
+
# label="๋ถํฌ ์ ํ"
|
| 217 |
+
# )
|
| 218 |
+
|
| 219 |
+
# run_btn = gr.Button("์คํ ์์", variant="primary")
|
| 220 |
+
|
| 221 |
+
# status_output = gr.Textbox(label="์ํ", lines=2)
|
| 222 |
+
# summary_output = gr.Textbox(label="์์ฝ", lines=20)
|
| 223 |
+
# results_output = gr.Dataframe(label="๊ฒฐ๊ณผ")
|
| 224 |
+
|
| 225 |
+
# run_btn.click(
|
| 226 |
+
# fn=app.run_experiment,
|
| 227 |
+
# inputs=[api_key_input, distribution_input],
|
| 228 |
+
# outputs=[status_output, summary_output, results_output]
|
| 229 |
+
# )
|
| 230 |
+
|
| 231 |
+
# with gr.Tab("๐ ๊ฒฐ๊ณผ ๋ถ์"):
|
| 232 |
+
# analyze_btn = gr.Button("๋ถ์ ์์", variant="primary")
|
| 233 |
+
|
| 234 |
+
# analyze_status = gr.Textbox(label="์ํ")
|
| 235 |
+
|
| 236 |
+
# with gr.Row():
|
| 237 |
+
# time_plot = gr.Image(label="์๋ต ์๊ฐ ๋น๊ต")
|
| 238 |
+
# token_plot = gr.Image(label="ํ ํฐ ์ฌ์ฉ๋ ๋น๊ต")
|
| 239 |
+
|
| 240 |
+
# with gr.Row():
|
| 241 |
+
# rag_plot = gr.Image(label="RAG ์ฌ์ฉ ํจํด")
|
| 242 |
+
# overfitting_plot = gr.Image(label="๊ณผ์ ํฉ ๋ถ์")
|
| 243 |
+
|
| 244 |
+
# analyze_btn.click(
|
| 245 |
+
# fn=app.analyze_results,
|
| 246 |
+
# outputs=[
|
| 247 |
+
# analyze_status,
|
| 248 |
+
# time_plot,
|
| 249 |
+
# token_plot,
|
| 250 |
+
# rag_plot,
|
| 251 |
+
# overfitting_plot
|
| 252 |
+
# ]
|
| 253 |
+
# )
|
| 254 |
+
|
| 255 |
+
# with gr.Tab("โน๏ธ ์ ๋ณด"):
|
| 256 |
+
# gr.Markdown("""
|
| 257 |
+
# ## ๐ ๋น๊ต ๋ชจ๋ธ
|
| 258 |
+
|
| 259 |
+
# | ๋ชจ๋ธ | ์ค๋ช
|
|
| 260 |
+
# |------|------|
|
| 261 |
+
# | QLoRA + RAG | ๊ธฐ์กด ์๋น์ค (QLoRA fine-tuning + RAG) |
|
| 262 |
+
# | QLoRA ๋จ๋
| RAG ์ ๊ฑฐ (QLoRA๋ง) |
|
| 263 |
+
# | Base + RAG | PEFT ์ ๊ฑฐ (Base ๋ชจ๋ธ + RAG) |
|
| 264 |
+
|
| 265 |
+
# ## ๐ ์ธก์ ์งํ
|
| 266 |
+
|
| 267 |
+
# - **๊ณผ์ ํฉ**: In-Distribution vs Out-Distribution ์ฑ๋ฅ ์ฐจ์ด
|
| 268 |
+
# - **๋ต๋ณ ์๋**: ํ๊ท ์๋ต ์๊ฐ
|
| 269 |
+
# - **ํ ํฐ ์ฌ์ฉ๋**: ํ๊ท ํ ํฐ ์๋น
|
| 270 |
+
# - **RAG ์ฌ์ฉ ํจํด**: RAG ํ์ฉ๋
|
| 271 |
+
|
| 272 |
+
# ## โฑ๏ธ ์์ ์์ ์๊ฐ
|
| 273 |
+
|
| 274 |
+
# - ๋ชจ๋ธ ๋ก๋ฉ: 3-5๋ถ
|
| 275 |
+
# - ์คํ ์คํ: 10-20๋ถ (25๊ฐ ์ง๋ฌธ)
|
| 276 |
+
# - ๊ฒฐ๊ณผ ๋ถ์: 1-2๋ถ
|
| 277 |
+
# """)
|
| 278 |
+
|
| 279 |
+
# return demo
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# if __name__ == "__main__":
|
| 283 |
+
# demo = create_interface()
|
| 284 |
+
# demo.launch(share=True)
|
data/eval_dataset.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"version": "1.0",
|
| 4 |
+
"description": "RFPilot ํ๊ฐ ๋ฐ์ดํฐ์
",
|
| 5 |
+
"created_by": "manual_annotation"
|
| 6 |
+
},
|
| 7 |
+
"in_distribution": [
|
| 8 |
+
{
|
| 9 |
+
"query": "์ฌ์
์ ์์ ์ ์ถ ๋ง๊ฐ์ผ์ ์ธ์ ์ธ๊ฐ์?",
|
| 10 |
+
"expected_answer": "2024๋
3์ 15์ผ๊น์ง์
๋๋ค.",
|
| 11 |
+
"category": "deadline",
|
| 12 |
+
"expected_type": "document",
|
| 13 |
+
"source_doc": "RFP_2024_001.hwp",
|
| 14 |
+
"metadata": {
|
| 15 |
+
"difficulty": "easy"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"query": "์ ์ ์์ฒญ์์ ์ ์ถ ์๋ฅ๋ ๋ฌด์์ธ๊ฐ์?",
|
| 20 |
+
"expected_answer": "๊ธฐ์ ์ ์์, ๊ฐ๊ฒฉ์ ์์, ์ฌ์
์๋ฑ๋ก์ฆ, ํ์ฌ์๊ฐ์๊ฐ ํ์ํฉ๋๋ค.",
|
| 21 |
+
"category": "requirements",
|
| 22 |
+
"expected_type": "document",
|
| 23 |
+
"source_doc": "RFP_2024_001.hwp",
|
| 24 |
+
"metadata": {
|
| 25 |
+
"difficulty": "medium"
|
| 26 |
+
}
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"query": "์ฌ์
์์ฐ ๊ท๋ชจ๋ ์ผ๋ง์ธ๊ฐ์?",
|
| 30 |
+
"expected_answer": "์ด 5์ต์์
๋๋ค.",
|
| 31 |
+
"category": "budget",
|
| 32 |
+
"expected_type": "document",
|
| 33 |
+
"source_doc": "RFP_2024_002.hwp",
|
| 34 |
+
"metadata": {
|
| 35 |
+
"difficulty": "easy"
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"out_distribution": [
|
| 40 |
+
{
|
| 41 |
+
"query": "ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?",
|
| 42 |
+
"expected_answer": "์์ธ์
๋๋ค.",
|
| 43 |
+
"category": "general_knowledge",
|
| 44 |
+
"expected_type": "out_of_scope",
|
| 45 |
+
"metadata": {
|
| 46 |
+
"difficulty": "easy"
|
| 47 |
+
}
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"query": "ํ์ด์ฌ์์ ๋ฆฌ์คํธ์ ํํ์ ์ฐจ์ด๋ ๋ฌด์์ธ๊ฐ์?",
|
| 51 |
+
"expected_answer": "๋ฆฌ์คํธ๋ ๊ฐ๋ณ(mutable)์ด๊ณ , ํํ์ ๋ถ๋ณ(immutable)์
๋๋ค.",
|
| 52 |
+
"category": "programming",
|
| 53 |
+
"expected_type": "out_of_scope",
|
| 54 |
+
"metadata": {
|
| 55 |
+
"difficulty": "medium"
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
]
|
| 59 |
+
}
|
data/eval_template.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
distribution,query,expected_answer,category,source_doc,metadata
|
| 2 |
+
in_distribution,์ฌ์
์ ์์ ์ ์ถ ๋ง๊ฐ์ผ์ ์ธ์ ์ธ๊ฐ์?,2024๋
3์ 15์ผ๊น์ง์
๋๋ค.,deadline,RFP_2024_001.hwp,"{""difficulty"": ""easy""}"
|
| 3 |
+
out_distribution,ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?,์์ธ์
๋๋ค.,general_knowledge,,"{""difficulty"": ""easy""}"
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
python-dotenv>=1.0.0
|
| 3 |
+
openai>=1.0.0
|
| 4 |
+
|
| 5 |
+
# LLM & RAG
|
| 6 |
+
llama-cpp-python>=0.2.90
|
| 7 |
+
chromadb>=0.4.0
|
| 8 |
+
sentence-transformers>=2.2.0
|
| 9 |
+
rank-bm25>=0.2.2
|
| 10 |
+
huggingface-hub>=0.19.0
|
| 11 |
+
|
| 12 |
+
# ๋ฌธ์ ์ฒ๋ฆฌ
|
| 13 |
+
pypdf>=3.17.0
|
| 14 |
+
python-docx>=1.1.0
|
| 15 |
+
olefile>=0.47
|
| 16 |
+
|
| 17 |
+
# ๋ฐ์ดํฐ ์ฒ๋ฆฌ
|
| 18 |
+
pandas>=2.0.0
|
| 19 |
+
numpy>=1.24.0
|
| 20 |
+
|
| 21 |
+
# ์๊ฐํ
|
| 22 |
+
matplotlib>=3.7.0
|
| 23 |
+
seaborn>=0.12.0
|
| 24 |
+
|
| 25 |
+
# ์ ํธ๋ฆฌํฐ
|
| 26 |
+
tqdm>=4.65.0
|
| 27 |
+
|
| 28 |
+
# Gradio
|
| 29 |
+
gradio>=6.0.0
|
src/analyze_results.py
ADDED
|
File without changes
|
src/compare_models.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
3๊ฐ์ง ๋ชจ๋ธ ๋น๊ต ์คํ
|
| 3 |
+
|
| 4 |
+
๋น๊ต ๋์:
|
| 5 |
+
1. QLoRA + RAG (๊ธฐ์กด ์๋น์ค)
|
| 6 |
+
2. QLoRA ๋จ๋
(RAG ์ ๊ฑฐ)
|
| 7 |
+
3. Base + RAG (PEFT ์ ๊ฑฐ)
|
| 8 |
+
|
| 9 |
+
์ธก์ ์งํ:
|
| 10 |
+
- ๊ณผ์ ํฉ ์ฌ๋ถ (In-Distribution vs Out-Distribution)
|
| 11 |
+
- ๋ต๋ณ ์๋ (elapsed_time, retrieval_time, generation_time)
|
| 12 |
+
- ํ ํฐ ๊ฐ์ (total_tokens, prompt_tokens, completion_tokens)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import time
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
from typing import Dict, List, Any
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
# ํ๋ก์ ํธ ๋ฃจํธ ๊ฒฝ๋ก ์ถ๊ฐ
|
| 25 |
+
project_root = Path(__file__).parent.parent
|
| 26 |
+
sys.path.insert(0, str(project_root))
|
| 27 |
+
|
| 28 |
+
from src.utils.config import RAGConfig
|
| 29 |
+
from eval_dataset import EvalDataset
|
| 30 |
+
|
| 31 |
+
# ๋ก๊น
์ค์
|
| 32 |
+
logging.basicConfig(
|
| 33 |
+
level=logging.INFO,
|
| 34 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 35 |
+
)
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ModelComparison:
|
| 40 |
+
"""๋ชจ๋ธ ๋น๊ต ์คํ ํด๋์ค"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, config=None, output_dir: str = "./results"):
|
| 43 |
+
"""์ด๊ธฐํ"""
|
| 44 |
+
self.config = config or RAGConfig()
|
| 45 |
+
self.output_dir = Path(output_dir)
|
| 46 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
# ํ์์คํฌํ
|
| 49 |
+
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 50 |
+
|
| 51 |
+
# ๋ฐ์ดํฐ์
|
| 52 |
+
self.dataset = EvalDataset()
|
| 53 |
+
|
| 54 |
+
# ๋ชจ๋ธ ํ์ดํ๋ผ์ธ
|
| 55 |
+
self.pipelines = {}
|
| 56 |
+
|
| 57 |
+
logger.info(f"โ
ModelComparison ์ด๊ธฐํ ์๋ฃ")
|
| 58 |
+
logger.info(f" ๊ฒฐ๊ณผ ์ ์ฅ ๊ฒฝ๋ก: {self.output_dir}")
|
| 59 |
+
|
| 60 |
+
def load_models(self):
|
| 61 |
+
"""3๊ฐ์ง ๋ชจ๋ธ ๋ก๋"""
|
| 62 |
+
logger.info("\n" + "="*60)
|
| 63 |
+
logger.info("๋ชจ๋ธ ๋ก๋ฉ ์์")
|
| 64 |
+
logger.info("="*60)
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
# 1. QLoRA + RAG (๊ธฐ์กด)
|
| 68 |
+
logger.info("\n[1/3] QLoRA + RAG ๋ชจ๋ธ ๋ก๋ฉ...")
|
| 69 |
+
from src.generator.generator_gguf import GGUFRAGPipeline
|
| 70 |
+
self.pipelines['qlora_rag'] = GGUFRAGPipeline(config=self.config)
|
| 71 |
+
logger.info("โ
QLoRA + RAG ๋ก๋ ์๋ฃ")
|
| 72 |
+
|
| 73 |
+
# 2. QLoRA ๋จ๋
(RAG ์ ๊ฑฐ)
|
| 74 |
+
logger.info("\n[2/3] QLoRA ๋จ๋
๋ชจ๋ธ ๋ก๋ฉ...")
|
| 75 |
+
from src.generator.generator_gguf_no_rag import GGUFNoRAGPipeline
|
| 76 |
+
self.pipelines['qlora_only'] = GGUFNoRAGPipeline(config=self.config)
|
| 77 |
+
logger.info("โ
QLoRA ๋จ๋
๋ก๋ ์๋ฃ")
|
| 78 |
+
|
| 79 |
+
# 3. Base + RAG (PEFT ์ ๊ฑฐ)
|
| 80 |
+
logger.info("\n[3/3] Base + RAG ๋ชจ๋ธ ๋ก๋ฉ...")
|
| 81 |
+
from src.generator.generator_gguf_base import GGUFBaseRAGPipeline
|
| 82 |
+
self.pipelines['base_rag'] = GGUFBaseRAGPipeline(config=self.config)
|
| 83 |
+
logger.info("โ
Base + RAG ๋ก๋ ์๋ฃ")
|
| 84 |
+
|
| 85 |
+
logger.info("\n" + "="*60)
|
| 86 |
+
logger.info(f"โ
์ด {len(self.pipelines)}๊ฐ ๋ชจ๋ธ ๋ก๋ ์๋ฃ")
|
| 87 |
+
logger.info("="*60 + "\n")
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.error(f"โ ๋ชจ๋ธ ๋ก๋ ์คํจ: {e}")
|
| 91 |
+
import traceback
|
| 92 |
+
traceback.print_exc()
|
| 93 |
+
raise
|
| 94 |
+
|
| 95 |
+
def run_single_query(
|
| 96 |
+
self,
|
| 97 |
+
model_name: str,
|
| 98 |
+
query: str,
|
| 99 |
+
query_info: Dict[str, Any]
|
| 100 |
+
) -> Dict[str, Any]:
|
| 101 |
+
"""๋จ์ผ ์ง๋ฌธ์ ๋ํ ๋ชจ๋ธ ์คํ"""
|
| 102 |
+
pipeline = self.pipelines[model_name]
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
start_time = time.time()
|
| 106 |
+
result = pipeline.generate_answer(query)
|
| 107 |
+
total_time = time.time() - start_time
|
| 108 |
+
|
| 109 |
+
# ๊ฒฐ๊ณผ ์ ๋ฆฌ
|
| 110 |
+
return {
|
| 111 |
+
'model': model_name,
|
| 112 |
+
'query': query,
|
| 113 |
+
'category': query_info.get('category', 'unknown'),
|
| 114 |
+
'expected_type': query_info.get('expected_type', 'unknown'),
|
| 115 |
+
'answer': result['answer'],
|
| 116 |
+
'used_retrieval': result.get('used_retrieval', False),
|
| 117 |
+
'query_type': result.get('query_type', 'unknown'),
|
| 118 |
+
'search_mode': result.get('search_mode', 'none'),
|
| 119 |
+
'elapsed_time': total_time,
|
| 120 |
+
'model_elapsed_time': result.get('elapsed_time', 0),
|
| 121 |
+
'usage': result.get('usage', {}),
|
| 122 |
+
'sources_count': len(result.get('sources', [])),
|
| 123 |
+
'success': True,
|
| 124 |
+
'error': None
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"โ ์ง๋ฌธ ์คํ ์คํจ [{model_name}]: {e}")
|
| 129 |
+
return {
|
| 130 |
+
'model': model_name,
|
| 131 |
+
'query': query,
|
| 132 |
+
'category': query_info.get('category', 'unknown'),
|
| 133 |
+
'expected_type': query_info.get('expected_type', 'unknown'),
|
| 134 |
+
'answer': None,
|
| 135 |
+
'used_retrieval': False,
|
| 136 |
+
'query_type': 'error',
|
| 137 |
+
'search_mode': 'none',
|
| 138 |
+
'elapsed_time': 0,
|
| 139 |
+
'model_elapsed_time': 0,
|
| 140 |
+
'usage': {},
|
| 141 |
+
'sources_count': 0,
|
| 142 |
+
'success': False,
|
| 143 |
+
'error': str(e)
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
def run_experiment(
|
| 147 |
+
self,
|
| 148 |
+
distribution: str = 'all',
|
| 149 |
+
save_results: bool = True
|
| 150 |
+
) -> Dict[str, List[Dict[str, Any]]]:
|
| 151 |
+
"""
|
| 152 |
+
์คํ ์คํ
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
distribution: 'in', 'out', 'all'
|
| 156 |
+
save_results: ๊ฒฐ๊ณผ ์ ์ฅ ์ฌ๋ถ
|
| 157 |
+
"""
|
| 158 |
+
logger.info("\n" + "="*60)
|
| 159 |
+
logger.info("์คํ ์์")
|
| 160 |
+
logger.info("="*60)
|
| 161 |
+
|
| 162 |
+
# ๋ฐ์ดํฐ์
์ค๋น
|
| 163 |
+
if distribution == 'in':
|
| 164 |
+
queries_dict = {'in_distribution': self.dataset.get_in_distribution()}
|
| 165 |
+
elif distribution == 'out':
|
| 166 |
+
queries_dict = {'out_distribution': self.dataset.get_out_distribution()}
|
| 167 |
+
else: # 'all'
|
| 168 |
+
queries_dict = self.dataset.get_all_queries()
|
| 169 |
+
|
| 170 |
+
# ๊ฒฐ๊ณผ ์ ์ฅ
|
| 171 |
+
all_results = {
|
| 172 |
+
'metadata': {
|
| 173 |
+
'timestamp': self.timestamp,
|
| 174 |
+
'distribution': distribution,
|
| 175 |
+
'models': list(self.pipelines.keys()),
|
| 176 |
+
'total_queries': sum(len(v) for v in queries_dict.values())
|
| 177 |
+
},
|
| 178 |
+
'results': {}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# ๊ฐ ๋ถํฌ์ ๋ํด ์คํ
|
| 182 |
+
for dist_type, queries in queries_dict.items():
|
| 183 |
+
logger.info(f"\n{'='*60}")
|
| 184 |
+
logger.info(f"{dist_type.upper()} ์คํ ({len(queries)}๊ฐ ์ง๋ฌธ)")
|
| 185 |
+
logger.info(f"{'='*60}")
|
| 186 |
+
|
| 187 |
+
dist_results = []
|
| 188 |
+
|
| 189 |
+
# ๊ฐ ์ง๋ฌธ์ ๋ํด
|
| 190 |
+
for i, query_info in enumerate(queries, 1):
|
| 191 |
+
query = query_info['query']
|
| 192 |
+
logger.info(f"\n[{i}/{len(queries)}] ์ง๋ฌธ: {query}")
|
| 193 |
+
|
| 194 |
+
# ๊ฐ ๋ชจ๋ธ์ ๋ํด
|
| 195 |
+
for model_name in self.pipelines.keys():
|
| 196 |
+
logger.info(f" โ {model_name} ์คํ ์ค...")
|
| 197 |
+
|
| 198 |
+
result = self.run_single_query(model_name, query, query_info)
|
| 199 |
+
dist_results.append(result)
|
| 200 |
+
|
| 201 |
+
if result['success']:
|
| 202 |
+
logger.info(f" โ
์๋ฃ ({result['elapsed_time']:.2f}์ด)")
|
| 203 |
+
else:
|
| 204 |
+
logger.warning(f" โ ์คํจ: {result['error']}")
|
| 205 |
+
|
| 206 |
+
all_results['results'][dist_type] = dist_results
|
| 207 |
+
|
| 208 |
+
# ๊ฒฐ๊ณผ ์ ์ฅ
|
| 209 |
+
if save_results:
|
| 210 |
+
self._save_results(all_results)
|
| 211 |
+
|
| 212 |
+
logger.info("\n" + "="*60)
|
| 213 |
+
logger.info("โ
์คํ ์๋ฃ")
|
| 214 |
+
logger.info("="*60 + "\n")
|
| 215 |
+
|
| 216 |
+
return all_results
|
| 217 |
+
|
| 218 |
+
def _save_results(self, results: Dict[str, Any]):
|
| 219 |
+
"""๊ฒฐ๊ณผ ์ ์ฅ"""
|
| 220 |
+
# JSON ํ์ผ๋ก ์ ์ฅ
|
| 221 |
+
output_file = self.output_dir / f"results_{self.timestamp}.json"
|
| 222 |
+
|
| 223 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 224 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
| 225 |
+
|
| 226 |
+
logger.info(f"๐ ๊ฒฐ๊ณผ ์ ์ฅ: {output_file}")
|
| 227 |
+
|
| 228 |
+
# ์์ฝ ํต๊ณ ์ ์ฅ
|
| 229 |
+
summary_file = self.output_dir / f"summary_{self.timestamp}.txt"
|
| 230 |
+
self._save_summary(results, summary_file)
|
| 231 |
+
|
| 232 |
+
logger.info(f"๐ ์์ฝ ์ ์ฅ: {summary_file}")
|
| 233 |
+
|
| 234 |
+
def _save_summary(self, results: Dict[str, Any], output_file: Path):
|
| 235 |
+
"""์์ฝ ํต๊ณ ์ ์ฅ"""
|
| 236 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 237 |
+
f.write("="*60 + "\n")
|
| 238 |
+
f.write("์คํ ๊ฒฐ๊ณผ ์์ฝ\n")
|
| 239 |
+
f.write("="*60 + "\n\n")
|
| 240 |
+
|
| 241 |
+
# ๋ฉํ๋ฐ์ดํฐ
|
| 242 |
+
metadata = results['metadata']
|
| 243 |
+
f.write(f"ํ์์คํฌํ: {metadata['timestamp']}\n")
|
| 244 |
+
f.write(f"๋ถํฌ: {metadata['distribution']}\n")
|
| 245 |
+
f.write(f"๋ชจ๋ธ: {', '.join(metadata['models'])}\n")
|
| 246 |
+
f.write(f"์ด ์ง๋ฌธ ์: {metadata['total_queries']}\n\n")
|
| 247 |
+
|
| 248 |
+
# ๊ฐ ๋ถํฌ๋ณ ํต๊ณ
|
| 249 |
+
for dist_type, dist_results in results['results'].items():
|
| 250 |
+
f.write(f"\n{'='*60}\n")
|
| 251 |
+
f.write(f"{dist_type.upper()} ๊ฒฐ๊ณผ\n")
|
| 252 |
+
f.write(f"{'='*60}\n\n")
|
| 253 |
+
|
| 254 |
+
# ๋ชจ๋ธ๋ณ๋ก ๊ทธ๋ฃนํ
|
| 255 |
+
model_stats = {}
|
| 256 |
+
for result in dist_results:
|
| 257 |
+
model = result['model']
|
| 258 |
+
if model not in model_stats:
|
| 259 |
+
model_stats[model] = []
|
| 260 |
+
model_stats[model].append(result)
|
| 261 |
+
|
| 262 |
+
# ๊ฐ ๋ชจ๋ธ๋ณ ํต๊ณ
|
| 263 |
+
for model, model_results in model_stats.items():
|
| 264 |
+
f.write(f"\n[{model}]\n")
|
| 265 |
+
|
| 266 |
+
# ์ฑ๊ณต/์คํจ
|
| 267 |
+
success_count = sum(1 for r in model_results if r['success'])
|
| 268 |
+
f.write(f" ์ฑ๊ณต: {success_count}/{len(model_results)}\n")
|
| 269 |
+
|
| 270 |
+
# ํ๊ท ์๊ฐ
|
| 271 |
+
avg_time = sum(r['elapsed_time'] for r in model_results if r['success']) / max(success_count, 1)
|
| 272 |
+
f.write(f" ํ๊ท ์๊ฐ: {avg_time:.3f}์ด\n")
|
| 273 |
+
|
| 274 |
+
# ํ๊ท ํ ํฐ
|
| 275 |
+
total_tokens = sum(r['usage'].get('total_tokens', 0) for r in model_results if r['success'])
|
| 276 |
+
avg_tokens = total_tokens / max(success_count, 1)
|
| 277 |
+
f.write(f" ํ๊ท ํ ํฐ: {avg_tokens:.1f}\n")
|
| 278 |
+
|
| 279 |
+
# RAG ์ฌ์ฉ๋ฅ
|
| 280 |
+
rag_count = sum(1 for r in model_results if r['used_retrieval'])
|
| 281 |
+
f.write(f" RAG ์ฌ์ฉ: {rag_count}/{len(model_results)} ({rag_count/len(model_results)*100:.1f}%)\n")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def main():
|
| 285 |
+
"""๋ฉ์ธ ํจ์"""
|
| 286 |
+
logger.info("="*60)
|
| 287 |
+
logger.info("RFPilot ๋ชจ๋ธ ๋น๊ต ์คํ")
|
| 288 |
+
logger.info("="*60)
|
| 289 |
+
|
| 290 |
+
# Config ๋ก๋
|
| 291 |
+
config = RAGConfig()
|
| 292 |
+
|
| 293 |
+
# ์คํ ์ด๊ธฐํ
|
| 294 |
+
experiment = ModelComparison(config=config, output_dir="./experiments/results")
|
| 295 |
+
|
| 296 |
+
# ๋ฐ์ดํฐ์
ํ์ธ
|
| 297 |
+
experiment.dataset.print_summary()
|
| 298 |
+
experiment.dataset.print_samples(n=3)
|
| 299 |
+
|
| 300 |
+
# ๋ชจ๋ธ ๋ก๋
|
| 301 |
+
experiment.load_models()
|
| 302 |
+
|
| 303 |
+
# ์คํ ์คํ
|
| 304 |
+
# ์ต์
1: ์ ์ฒด ์คํ
|
| 305 |
+
results = experiment.run_experiment(distribution='all', save_results=True)
|
| 306 |
+
|
| 307 |
+
# ์ต์
2: In-Distribution๋ง
|
| 308 |
+
# results = experiment.run_experiment(distribution='in', save_results=True)
|
| 309 |
+
|
| 310 |
+
# ์ต์
3: Out-Distribution๋ง
|
| 311 |
+
# results = experiment.run_experiment(distribution='out', save_results=True)
|
| 312 |
+
|
| 313 |
+
logger.info(f"\nโ
๋ชจ๋ ์คํ ์๋ฃ!")
|
| 314 |
+
logger.info(f" ๊ฒฐ๊ณผ ์ ์ฅ ์์น: {experiment.output_dir}")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
|
| 318 |
+
try:
|
| 319 |
+
main()
|
| 320 |
+
except KeyboardInterrupt:
|
| 321 |
+
logger.info("\nโ ๏ธ ์ฌ์ฉ์์ ์ํด ์ค๋จ๋จ")
|
| 322 |
+
except Exception as e:
|
| 323 |
+
logger.error(f"\nโ ์คํ ์คํจ: {e}")
|
| 324 |
+
import traceback
|
| 325 |
+
traceback.print_exc()
|
src/create_eval_dataset.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ํ๊ฐ ๋ฐ์ดํฐ์
์์ฑ ๋๊ตฌ
|
| 3 |
+
|
| 4 |
+
์ค์ RFP ๋ฌธ์์์ ์ง๋ฌธ-๋ต๋ณ ์์ ๋ง๋ค์ด
|
| 5 |
+
Ground Truth๊ฐ ์๋ ํ๊ฐ ๋ฐ์ดํฐ์
์ ์์ฑํฉ๋๋ค.
|
| 6 |
+
|
| 7 |
+
์ฌ์ฉ๋ฒ:
|
| 8 |
+
python create_eval_dataset.py --input data/rag_chunks_final.csv --output data/eval_dataset.json
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import csv
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Dict, Any
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EvalDatasetCreator:
|
| 19 |
+
"""ํ๊ฐ ๋ฐ์ดํฐ์
์์ฑ ํด๋์ค"""
|
| 20 |
+
|
| 21 |
+
def __init__(self):
|
| 22 |
+
self.dataset = {
|
| 23 |
+
"metadata": {
|
| 24 |
+
"version": "1.0",
|
| 25 |
+
"description": "RFPilot ํ๊ฐ ๋ฐ์ดํฐ์
",
|
| 26 |
+
"created_by": "manual_annotation"
|
| 27 |
+
},
|
| 28 |
+
"in_distribution": [],
|
| 29 |
+
"out_distribution": []
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
def add_in_distribution_sample(
|
| 33 |
+
self,
|
| 34 |
+
query: str,
|
| 35 |
+
expected_answer: str,
|
| 36 |
+
category: str,
|
| 37 |
+
source_doc: str = None,
|
| 38 |
+
metadata: Dict[str, Any] = None
|
| 39 |
+
):
|
| 40 |
+
"""In-Distribution ์ํ ์ถ๊ฐ"""
|
| 41 |
+
sample = {
|
| 42 |
+
"query": query,
|
| 43 |
+
"expected_answer": expected_answer,
|
| 44 |
+
"category": category,
|
| 45 |
+
"expected_type": "document",
|
| 46 |
+
"source_doc": source_doc,
|
| 47 |
+
"metadata": metadata or {}
|
| 48 |
+
}
|
| 49 |
+
self.dataset["in_distribution"].append(sample)
|
| 50 |
+
|
| 51 |
+
def add_out_distribution_sample(
|
| 52 |
+
self,
|
| 53 |
+
query: str,
|
| 54 |
+
expected_answer: str,
|
| 55 |
+
category: str,
|
| 56 |
+
metadata: Dict[str, Any] = None
|
| 57 |
+
):
|
| 58 |
+
"""Out-Distribution ์ํ ์ถ๊ฐ"""
|
| 59 |
+
sample = {
|
| 60 |
+
"query": query,
|
| 61 |
+
"expected_answer": expected_answer,
|
| 62 |
+
"category": category,
|
| 63 |
+
"expected_type": "out_of_scope",
|
| 64 |
+
"metadata": metadata or {}
|
| 65 |
+
}
|
| 66 |
+
self.dataset["out_distribution"].append(sample)
|
| 67 |
+
|
| 68 |
+
def create_template_dataset(self):
|
| 69 |
+
"""ํ
ํ๋ฆฟ ๋ฐ์ดํฐ์
์์ฑ (์๋ ์์ฑ์ฉ)"""
|
| 70 |
+
print("๐ ํ
ํ๋ฆฟ ๋ฐ์ดํฐ์
์์ฑ ์ค...")
|
| 71 |
+
|
| 72 |
+
# In-Distribution ํ
ํ๋ฆฟ
|
| 73 |
+
in_dist_templates = [
|
| 74 |
+
{
|
| 75 |
+
"query": "์ฌ์
์ ์์ ์ ์ถ ๋ง๊ฐ์ผ์ ์ธ์ ์ธ๊ฐ์?",
|
| 76 |
+
"expected_answer": "2024๋
3์ 15์ผ๊น์ง์
๋๋ค.", # ์ค์ ๋ฌธ์์์ ์ถ์ถ
|
| 77 |
+
"category": "deadline",
|
| 78 |
+
"source_doc": "RFP_2024_001.hwp",
|
| 79 |
+
"metadata": {"difficulty": "easy"}
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"query": "์ ์ ์์ฒญ์์ ์ ์ถ ์๋ฅ๋ ๋ฌด์์ธ๊ฐ์?",
|
| 83 |
+
"expected_answer": "๊ธฐ์ ์ ์์, ๊ฐ๊ฒฉ์ ์์, ์ฌ์
์๋ฑ๋ก์ฆ, ํ์ฌ์๊ฐ์๊ฐ ํ์ํฉ๋๋ค.",
|
| 84 |
+
"category": "requirements",
|
| 85 |
+
"source_doc": "RFP_2024_001.hwp",
|
| 86 |
+
"metadata": {"difficulty": "medium"}
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"query": "์ฌ์
์์ฐ ๊ท๋ชจ๋ ์ผ๋ง์ธ๊ฐ์?",
|
| 90 |
+
"expected_answer": "์ด 5์ต์์
๋๋ค.",
|
| 91 |
+
"category": "budget",
|
| 92 |
+
"source_doc": "RFP_2024_002.hwp",
|
| 93 |
+
"metadata": {"difficulty": "easy"}
|
| 94 |
+
},
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
# Out-Distribution ํ
ํ๋ฆฟ
|
| 98 |
+
out_dist_templates = [
|
| 99 |
+
{
|
| 100 |
+
"query": "ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?",
|
| 101 |
+
"expected_answer": "์์ธ์
๋๋ค.",
|
| 102 |
+
"category": "general_knowledge",
|
| 103 |
+
"metadata": {"difficulty": "easy"}
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"query": "ํ์ด์ฌ์์ ๋ฆฌ์คํธ์ ํํ์ ์ฐจ์ด๋ ๋ฌด์์ธ๊ฐ์?",
|
| 107 |
+
"expected_answer": "๋ฆฌ์คํธ๋ ๊ฐ๋ณ(mutable)์ด๊ณ , ํํ์ ๋ถ๋ณ(immutable)์
๋๋ค.",
|
| 108 |
+
"category": "programming",
|
| 109 |
+
"metadata": {"difficulty": "medium"}
|
| 110 |
+
},
|
| 111 |
+
]
|
| 112 |
+
|
| 113 |
+
# ๋ฐ์ดํฐ์
์ ์ถ๊ฐ
|
| 114 |
+
for sample in in_dist_templates:
|
| 115 |
+
self.add_in_distribution_sample(**sample)
|
| 116 |
+
|
| 117 |
+
for sample in out_dist_templates:
|
| 118 |
+
self.add_out_distribution_sample(**sample)
|
| 119 |
+
|
| 120 |
+
print(f"โ
ํ
ํ๋ฆฟ ์์ฑ ์๋ฃ")
|
| 121 |
+
print(f" - In-Distribution: {len(in_dist_templates)}๊ฐ")
|
| 122 |
+
print(f" - Out-Distribution: {len(out_dist_templates)}๊ฐ")
|
| 123 |
+
print(f"\nโ ๏ธ ์ด ํ
ํ๋ฆฟ์ ์์ ํ์ฌ ์ค์ ๋ฐ์ดํฐ๋ฅผ ์ฑ์์ฃผ์ธ์!")
|
| 124 |
+
|
| 125 |
+
def load_from_csv(self, csv_path: str):
|
| 126 |
+
"""CSV์์ ๋ฐ์ดํฐ์
๋ก๋"""
|
| 127 |
+
print(f"๐ฅ CSV ๋ก๋ ์ค: {csv_path}")
|
| 128 |
+
|
| 129 |
+
with open(csv_path, 'r', encoding='utf-8') as f:
|
| 130 |
+
reader = csv.DictReader(f)
|
| 131 |
+
for row in reader:
|
| 132 |
+
distribution = row.get('distribution', 'in_distribution')
|
| 133 |
+
|
| 134 |
+
if distribution == 'in_distribution':
|
| 135 |
+
self.add_in_distribution_sample(
|
| 136 |
+
query=row['query'],
|
| 137 |
+
expected_answer=row['expected_answer'],
|
| 138 |
+
category=row['category'],
|
| 139 |
+
source_doc=row.get('source_doc'),
|
| 140 |
+
metadata=json.loads(row.get('metadata', '{}'))
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
self.add_out_distribution_sample(
|
| 144 |
+
query=row['query'],
|
| 145 |
+
expected_answer=row['expected_answer'],
|
| 146 |
+
category=row['category'],
|
| 147 |
+
metadata=json.loads(row.get('metadata', '{}'))
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
print(f"โ
CSV ๋ก๋ ์๋ฃ")
|
| 151 |
+
|
| 152 |
+
def save_json(self, output_path: str):
|
| 153 |
+
"""JSON ํ์์ผ๋ก ์ ์ฅ"""
|
| 154 |
+
output_path = Path(output_path)
|
| 155 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 156 |
+
|
| 157 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 158 |
+
json.dump(self.dataset, f, ensure_ascii=False, indent=2)
|
| 159 |
+
|
| 160 |
+
print(f"๐พ ์ ์ฅ ์๋ฃ: {output_path}")
|
| 161 |
+
|
| 162 |
+
def save_csv_template(self, output_path: str):
|
| 163 |
+
"""์๋ ์์ฑ์ฉ CSV ํ
ํ๋ฆฟ ์ ์ฅ"""
|
| 164 |
+
output_path = Path(output_path)
|
| 165 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 166 |
+
|
| 167 |
+
with open(output_path, 'w', encoding='utf-8', newline='') as f:
|
| 168 |
+
writer = csv.DictWriter(f, fieldnames=[
|
| 169 |
+
'distribution', 'query', 'expected_answer',
|
| 170 |
+
'category', 'source_doc', 'metadata'
|
| 171 |
+
])
|
| 172 |
+
writer.writeheader()
|
| 173 |
+
|
| 174 |
+
# In-Distribution ์์
|
| 175 |
+
writer.writerow({
|
| 176 |
+
'distribution': 'in_distribution',
|
| 177 |
+
'query': '์ฌ์
์ ์์ ์ ์ถ ๋ง๊ฐ์ผ์ ์ธ์ ์ธ๊ฐ์?',
|
| 178 |
+
'expected_answer': '2024๋
3์ 15์ผ๊น์ง์
๋๋ค.',
|
| 179 |
+
'category': 'deadline',
|
| 180 |
+
'source_doc': 'RFP_2024_001.hwp',
|
| 181 |
+
'metadata': '{"difficulty": "easy"}'
|
| 182 |
+
})
|
| 183 |
+
|
| 184 |
+
# Out-Distribution ์์
|
| 185 |
+
writer.writerow({
|
| 186 |
+
'distribution': 'out_distribution',
|
| 187 |
+
'query': 'ํ๊ตญ์ ์๋๋ ์ด๋์ธ๊ฐ์?',
|
| 188 |
+
'expected_answer': '์์ธ์
๋๋ค.',
|
| 189 |
+
'category': 'general_knowledge',
|
| 190 |
+
'source_doc': '',
|
| 191 |
+
'metadata': '{"difficulty": "easy"}'
|
| 192 |
+
})
|
| 193 |
+
|
| 194 |
+
print(f"๐ CSV ํ
ํ๋ฆฟ ์ ์ฅ: {output_path}")
|
| 195 |
+
print(f" โ ์ด ํ์ผ์ ์์ ํ์ฌ ์ค์ ๋ฐ์ดํฐ๋ฅผ ์ฑ์์ฃผ์ธ์!")
|
| 196 |
+
|
| 197 |
+
def print_summary(self):
|
| 198 |
+
"""๋ฐ์ดํฐ์
์์ฝ ์ถ๋ ฅ"""
|
| 199 |
+
print("\n" + "="*60)
|
| 200 |
+
print("๋ฐ์ดํฐ์
์์ฝ")
|
| 201 |
+
print("="*60)
|
| 202 |
+
print(f"In-Distribution: {len(self.dataset['in_distribution'])}๊ฐ")
|
| 203 |
+
print(f"Out-Distribution: {len(self.dataset['out_distribution'])}๊ฐ")
|
| 204 |
+
print(f"์ด ์ํ: {len(self.dataset['in_distribution']) + len(self.dataset['out_distribution'])}๊ฐ")
|
| 205 |
+
print("="*60 + "\n")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main():
    """CLI entry point: emit the template files, or convert a filled CSV to JSON."""
    arg_parser = argparse.ArgumentParser(description='ํ๊ฐ ๋ฐ์ดํฐ์์์ฑ')
    arg_parser.add_argument('--mode', choices=['template', 'csv'], default='template',
                            help='์์ฑ ๋ชจ๋: template (ํํ๋ฆฟ ์์ฑ) ๋๋ csv (CSV์์ ๋ก๋)')
    arg_parser.add_argument('--input', type=str, help='์๋ ฅ CSV ํ์ผ ๊ฒฝ๋ก')
    arg_parser.add_argument('--output', type=str, default='data/eval_dataset.json',
                            help='์ถ๋ ฅ JSON ํ์ผ ๊ฒฝ๋ก')
    arg_parser.add_argument('--csv-template', type=str, default='data/eval_template.csv',
                            help='CSV ํํ๋ฆฟ ์ ์ฅ ๊ฒฝ๋ก')
    args = arg_parser.parse_args()

    # Guard clause: CSV mode cannot run without an input file.
    if args.mode == 'csv' and not args.input:
        print("โ CSV ๋ชจ๋์์๋ --input ์ต์์ด ํ์ํฉ๋๋ค.")
        return

    creator = EvalDatasetCreator()

    if args.mode == 'template':
        print("๐ ํํ๋ฆฟ ๋ชจ๋")
        creator.create_template_dataset()
        creator.save_json(args.output)
        creator.save_csv_template(args.csv_template)
    elif args.mode == 'csv':
        print("๐ฅ CSV ๋ชจ๋")
        creator.load_from_csv(args.input)
        creator.save_json(args.output)

    creator.print_summary()

    print("\nโ ์๋ฃ!")
    print(f"\n๋ค์ ๋จ๊ณ:")
    print(f"1. {args.csv_template} ํ์ผ์ ์ด์ด์ ์ค์ ๋ฐ์ดํฐ ์์ฑ")
    print(f"2. python create_eval_dataset.py --mode csv --input {args.csv_template} --output {args.output}")
    print(f"3. ์์ฑ๋ {args.output}์ ์คํ์ ์ฌ์ฉ")


if __name__ == "__main__":
    main()
src/eval_dataset.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, List, Any
|
| 4 |
+
|
| 5 |
+
class EvalDataset:
    """Thin read-only accessor for the evaluation dataset JSON file."""

    def __init__(self, dataset_path: str = "data/eval_dataset.json"):
        # The file is resolved and parsed eagerly at construction time.
        self.dataset_path = Path(dataset_path)
        self.data = self._load_dataset()

    def _load_dataset(self) -> Dict[str, Any]:
        """Read and parse the JSON dataset from disk."""
        raw = self.dataset_path.read_text(encoding='utf-8')
        return json.loads(raw)

    def get_in_distribution(self) -> List[Dict[str, Any]]:
        """Samples expected to be answerable from the indexed documents."""
        return self.data.get('in_distribution', [])

    def get_out_distribution(self) -> List[Dict[str, Any]]:
        """Samples outside the scope of the document collection."""
        return self.data.get('out_distribution', [])

    def get_all_queries(self) -> Dict[str, List[Dict[str, Any]]]:
        """Both splits keyed by split name."""
        return {
            split: getter()
            for split, getter in (
                ('in_distribution', self.get_in_distribution),
                ('out_distribution', self.get_out_distribution),
            )
        }

    def print_summary(self):
        """Print the sample count of each split."""
        for label, split in (("In-Distribution", self.get_in_distribution()),
                             ("Out-Distribution", self.get_out_distribution())):
            print(f"{label}: {len(split)}๊ฐ")

    def print_samples(self, n: int = 3):
        """Print the first *n* in-distribution queries."""
        print("\n[In-Distribution ์ํ]")
        for entry in self.get_in_distribution()[:n]:
            print(f"  - {entry['query']}")
src/generator/generator_gguf.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llama_cpp import Llama
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from src.utils.config import RAGConfig
|
| 8 |
+
from src.router.query_router import QueryRouter
|
| 9 |
+
from src.prompts.dynamic_prompts import PromptManager
|
| 10 |
+
|
| 11 |
+
# ๋ก๊น
์ค์
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GGUFGenerator:
    """
    GGUF-based Llama-3 generator.

    Loads a GGUF-format model through llama.cpp (llama-cpp-python) and
    answers tender/RFP-related questions. The heavy model load is deferred
    to load_model(); construction only records configuration.
    """

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config = None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: str = "๋น์ ์ RFP(์ ์์์ฒญ์) ๋ถ์ ๋ฐ ์์ฝ ์ ๋ฌธ๊ฐ์๋๋ค."
    ):
        """
        Initialize the generator (the model itself is NOT loaded here).

        Args:
            model_path: Path to a GGUF model file. Only used when the
                config's USE_MODEL_HUB flag is False (see load_model).
            n_gpu_layers: Layers to offload to the GPU (0 = CPU only,
                35 = whole model on GPU).
            n_ctx: Maximum context length requested from llama.cpp.
            n_threads: Number of CPU threads for inference.
            config: Optional RAGConfig; a default instance is created when None.
            max_new_tokens: Default cap on generated tokens.
            temperature: Default sampling temperature (0.0~1.0).
            top_p: Default nucleus-sampling parameter.
            system_prompt: Default system prompt used by format_prompt().
        """
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Model handle; populated lazily by load_model().
        self.model = None

        logger.info(f"GGUFGenerator ์ด๊ธฐํ ์๋ฃ")

    def load_model(self) -> None:
        """
        Load the GGUF model.

        Logic:
        1. Read USE_MODEL_HUB from the config (defaults to True when absent).
        2-A. True  -> download the model file from the Hugging Face Hub.
        2-B. False -> use the local file passed to the constructor.
        3. Instantiate the llama.cpp model and sanity-check the applied n_ctx.

        Raises:
            FileNotFoundError: local mode and the model file does not exist.
            RuntimeError: any other failure while loading the model.
        """

        # Guard against loading twice (load is expensive).
        if self.model is not None:
            logger.info("๋ชจ๋ธ์ด ์ด๋ฏธ ๋ก๋๋์ด ์์ต๋๋ค.")
            return

        try:
            # Read USE_MODEL_HUB from config (True by default when missing).
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            # Resolve the model path depending on the hub setting.
            if use_model_hub:
                # === Download from the Hugging Face Model Hub ===
                model_hub_repo = getattr(self.config, 'MODEL_HUB_REPO', 'beomi/Llama-3-Open-Ko-8B-gguf')
                model_hub_filename = getattr(self.config, 'MODEL_HUB_FILENAME', 'ggml-model-Q4_K_M.gguf')
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info(f"๐ฅ Model Hub์์ ๋ค์ด๋ก๋: {model_hub_repo}")

                # Imported lazily so the dependency is only needed in hub mode.
                from huggingface_hub import hf_hub_download

                model_path = hf_hub_download(
                    repo_id=model_hub_repo,
                    filename=model_hub_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                    local_dir_use_symlinks=False  # copy the real file instead of symlinking
                )

                logger.info(f"โ ๋ค์ด๋ก๋ ์๋ฃ: {model_path}")

            else:
                # === Use a local file ===
                model_path = self.model_path  # path received in the constructor

                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"โ ๋ก์ปฌ ๋ชจ๋ธ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {model_path}\n"
                        f"   USE_MODEL_HUB=true๋ก ์ค์ ํ๊ฑฐ๋ ๋ชจ๋ธ ํ์ผ์ ์ค๋นํ์ธ์."
                    )

                logger.info(f"๐ ๋ก์ปฌ ๋ชจ๋ธ ์ฌ์ฉ: {model_path}")

            # === Common: load the model via llama.cpp ===
            logger.info(f"๐ GGUF ๋ชจ๋ธ ๋ก๋ ์ค...")
            logger.info(f"   GPU ๋ ์ด์ด: {self.n_gpu_layers}")
            logger.info(f"   ์ปจํ์คํธ: {self.n_ctx}")

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,  # enable llama.cpp debug logging
            )

            # Verify the context size llama.cpp actually applied; it may be
            # smaller than requested under memory pressure.
            actual_n_ctx = self.model.n_ctx()
            logger.info("โ GGUF ๋ชจ๋ธ ๋ก๋ ์๋ฃ!")
            logger.info(f"  - ์ค์ ํ n_ctx: {self.n_ctx}")
            logger.info(f"  - ์ค์  n_ctx: {actual_n_ctx}")

            if actual_n_ctx < self.n_ctx:
                logger.warning(f"โ ๏ธ n_ctx๊ฐ ์์๋ณด๋ค ์์ต๋๋ค: {actual_n_ctx} < {self.n_ctx}")
                logger.warning(f"   ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ์ผ ์ ์์ต๋๋ค. n_gpu_layers๋ฅผ ์ค์ฌ๋ณด์ธ์.")

        except FileNotFoundError as e:
            logger.error(f"โ ๋ชจ๋ธ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {e}")
            raise
        except Exception as e:
            logger.error(f"โ ๋ชจ๋ธ ๋ก๋ ์คํจ: {e}")
            raise RuntimeError(f"๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")

    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Build a simple plain-text prompt for the GGUF model.

        Uses a lightweight "### section" template instead of Llama-3
        special tokens.

        Args:
            question: User question.
            context: Optional retrieved reference text; prepended when given.
            system_prompt: Optional override; falls back to self.system_prompt.

        Returns:
            The fully formatted prompt string, ending at the answer section.
        """
        # Fall back to the instance-level system prompt.
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Include retrieved context when available.
        if context is not None:
            user_message = f"์ฐธ๊ณ  ๋ฌธ์:\n{context}\n\n์ง๋ฌธ: {question}"
        else:
            user_message = question

        # Simple Korean template (no special tokens).
        formatted_prompt = f"""### ์์คํ
{system_prompt}

### ์ฌ์ฉ์
{user_message}

### ๋ต๋ณ
"""

        return formatted_prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """
        Generate a completion for an already-formatted prompt.

        Args:
            prompt: Formatted prompt (see format_prompt).
            max_new_tokens: Per-call override of the generation-length cap.
            temperature: Per-call sampling-temperature override.
            top_p: Per-call nucleus-sampling override.

        Returns:
            The generated response text (stripped).

        Raises:
            RuntimeError: when the model is not loaded, or generation fails.
        """
        # The model must have been loaded first.
        if self.model is None:
            raise RuntimeError(
                "๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. load_model()์ ๋จผ์  ํธ์ถํ์ธ์."
            )

        # Resolve per-call overrides against the instance defaults.
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p

        try:
            logger.info(f"๐ ์์ฑ ์์ (max_tokens={max_new_tokens}, temp={temperature})")
            start_time = time.time()

            # Run llama.cpp completion.
            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,  # do not echo the prompt back in the output
                stop=[
                    # Template section delimiters
                    "###", "\n\n###",
                    "### ์ฌ์ฉ์", "\n์ฌ์ฉ์:",
                    "</s>",
                    # Block meta text the model tends to emit
                    "ํ๊ตญ์ด ๋ต๋ณ", "ํ๊ตญ์ด๋ก ๋ต๋ณ", "์ง์นจ:",
                    "๋ฌธ์ฅ", "(๋ฌธ์ฅ",
                    # Block question patterns (prevents the model from
                    # generating follow-up questions after the answer)
                    "\n\n",  # paragraph break
                    "?",     # question mark
                    "์?", "๊น?", "๋์?", "์ต๋๊น?"  # Korean interrogative endings
                ],
            )

            elapsed = time.time() - start_time
            logger.info(f"โ ์์ฑ ์๋ฃ: {elapsed:.2f}์ด")

            # Extract the completion text from the llama.cpp response dict.
            response = output['choices'][0]['text'].strip()

            logger.info(f"๐ ์๋ต ๊ธธ์ด: {len(response)} ๊ธ์")
            return response

        except Exception as e:
            logger.error(f"โ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}")
            raise RuntimeError(f"ํ์คํธ ์์ฑ ์คํจ: {e}")

    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Answer a question (convenience wrapper: format_prompt + generate).

        Args:
            question: User question.
            context: Optional retrieved context.
            system_prompt: Optional system-prompt override.
            **kwargs: Extra parameters forwarded to generate().

        Returns:
            The generated response.
        """
        # Format the prompt.
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )

        # Generate the response.
        response = self.generate(prompt, **kwargs)

        return response
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class GGUFRAGPipeline:
    """
    GGUF generator + RAG integrated pipeline.

    Exposes an interface compatible with chatbot_app.py
    (generate_answer / chat / clear_history / get_history / set_search_config).
    """

    def __init__(
        self,
        config=None,
        model: str = None,  # kept for interface compatibility (unused)
        top_k: int = None,
        # GPU settings (optional; override the config when given)
        n_gpu_layers: int = None,
        n_ctx: int = None,
        n_threads: int = None,
        max_new_tokens: int = None,
        temperature: float = None,
        top_p: float = None,
        search_mode: str = None,
        alpha: float = None
    ):
        """
        Initialize the pipeline: retriever, generator, and model load.

        Args:
            config: RAGConfig instance (a default one is created when None).
            model: Unused; accepted for compatibility with other pipelines.
            top_k: Default number of documents to retrieve.
            n_gpu_layers: GPU layer count (overrides config GGUF_N_GPU_LAYERS).
            n_ctx: Context length (overrides config GGUF_N_CTX).
            n_threads: CPU thread count (overrides config GGUF_N_THREADS).
            max_new_tokens: Generation cap (overrides config GGUF_MAX_NEW_TOKENS).
            temperature: Sampling temperature (overrides config GGUF_TEMPERATURE).
            top_p: Nucleus sampling (overrides config GGUF_TOP_P).
            search_mode: Retrieval mode ('embedding', 'embedding_rerank',
                'hybrid', 'hybrid_rerank').
            alpha: Embedding weight for hybrid search.
        """
        self.config = config or RAGConfig()

        # Defaults come from the config when a parameter is omitted.
        # NOTE(review): `top_k or ...` treats 0 as "unset"; 0 is not a useful
        # top_k, so this is kept as-is.
        self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)

        # Retrieval settings.
        self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
        self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)

        # Retriever (imported lazily to keep module import cheap).
        logger.info("RAGRetriever ์ด๊ธฐํ ์ค...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)

        # GGUF settings: explicit parameter wins, then config, then fallback.
        gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
        gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
        gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
        gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
        gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
        gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)

        # Model path (fallback when not configured).
        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # System prompt (fallback when not configured).
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '๋น์ ์ ํ๊ตญ ๊ณต๊ณต๊ธฐ๊ด ์ฌ์์ ์์ ๋ถ์ ์ ๋ฌธ๊ฐ์๋๋ค.')

        # Generator initialization.
        logger.info("GGUFGenerator ์ด๊ธฐํ ์ค...")
        logger.info(f"   GPU ๋ ์ด์ด: {gguf_n_gpu_layers}")
        logger.info(f"   ์ปจํ์คํธ: {gguf_n_ctx}")
        logger.info(f"   ์ค๋ ๋: {gguf_n_threads}")
        logger.info(f"   ๋ชจ๋ธ ๊ฒฝ๋ก: {gguf_model_path}")

        self.generator = GGUFGenerator(
            model_path=gguf_model_path,
            n_gpu_layers=gguf_n_gpu_layers,
            n_ctx=gguf_n_ctx,
            n_threads=gguf_n_threads,
            config=self.config,
            max_new_tokens=gguf_max_new_tokens,
            temperature=gguf_temperature,
            top_p=gguf_top_p,
            system_prompt=system_prompt
        )

        # Model load (slow; may download from the hub).
        logger.info("GGUF ๋ชจ๋ธ ๋ก๋ ์ค...")
        self.generator.load_model()

        # Conversation history as chat-style role/content dicts.
        self.chat_history: List[Dict] = []

        # Last retrieval result, kept for the `sources` field of responses.
        self._last_retrieved_docs = []

        logger.info("โ GGUFRAGPipeline ์ด๊ธฐํ ์๋ฃ")
        logger.info(f"  - ๊ฒ์ ๋ชจ๋: {self.search_mode}")
        logger.info(f"  - ๊ธฐ๋ณธ top_k: {self.top_k}")

    def _retrieve_and_format(self, query: str) -> str:
        """Run retrieval for *query* and return the formatted context string."""
        # Dispatch to the retriever method matching the configured mode;
        # unknown modes fall back to plain embedding search.
        if self.search_mode == "embedding":
            docs = self.retriever.search(query, top_k=self.top_k)
        elif self.search_mode == "embedding_rerank":
            docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
        elif self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            docs = self.retriever.search(query, top_k=self.top_k)

        # Remember the last retrieval for _format_sources().
        self._last_retrieved_docs = docs

        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list) -> str:
        """
        Convert retrieved documents into a single context string.

        Truncates when the accumulated text exceeds a character budget so
        the prompt stays within the model's context window.
        """
        if not retrieved_docs:
            return "๊ด๋ จ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."

        context_parts = []
        max_context_chars = 8000  # roughly ~2000 tokens, with headroom

        current_length = 0
        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"[๋ฌธ์ {i}]\n{doc['content']}\n"
            doc_length = len(doc_text)

            # Stop adding documents once the budget would be exceeded.
            if current_length + doc_length > max_context_chars:
                logger.warning(f"โ ๏ธ ์ปจํ์คํธ ๊ธธ์ด ์ ํ: {i-1}๊ฐ ๋ฌธ์๋ง ์ฌ์ฉ (์ต๋ {max_context_chars}์)")
                break

            context_parts.append(doc_text)
            current_length += doc_length

        return "\n".join(context_parts)

    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the `sources` response format."""
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }

            # The score field present depends on the retrieval mode used.
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'

            sources.append(source_info)

        return sources

    def _estimate_usage(self, query: str, answer: str) -> dict:
        """Rough token-usage estimate (word count x 2; not an exact tokenizer)."""
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2

        return {
            'total_tokens': prompt_tokens + completion_tokens,
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens
        }

    def generate_answer(
        self,
        query: str,
        top_k: int = None,
        search_mode: str = None,
        alpha: float = None
    ) -> dict:
        """
        Generate an answer (main chatbot_app.py-compatible method).

        Args:
            query: User question.
            top_k: Optional per-call override of the retrieval count.
            search_mode: Optional per-call override of the retrieval mode.
            alpha: Optional per-call override of the embedding weight.

        Returns:
            dict with keys: answer, sources, used_retrieval, query_type,
            search_mode, routing_info, elapsed_time, usage.

        Raises:
            RuntimeError: wraps any failure during routing/retrieval/generation.
        """
        try:
            start_time = time.time()

            # Apply per-call overrides BEFORE retrieval (they are sticky).
            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # ===== 1. Decide whether to retrieve, via the query router =====
            router = QueryRouter()
            classification = router.classify(query)
            query_type = classification['type']  # 'greeting'/'thanks'/'document'/'out_of_scope'

            logger.info(f"๐ ๋ถ๋ฅ: {query_type} "
                        f"(์ ๋ขฐ๋: {classification['confidence']:.2f})")

            # 2. Handle by query type.
            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # Skip retrieval entirely.
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []

                # Dynamic prompt selection (GGUF variant).
                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info(f"โญ๏ธ RAG ์คํต: {query_type}")

            else:
                # BUGFIX: this branch was `elif query_type == 'document'`,
                # which left context/system_prompt/used_retrieval unbound
                # (UnboundLocalError) for any router label outside the four
                # documented ones. Unknown labels now fall back to the
                # document/RAG path; behavior for 'document' is unchanged.
                context = self._retrieve_and_format(query)
                used_retrieval = True

                # Dynamic prompt (GGUF variant, context-aware).
                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info(f"๐ RAG ์ํ: {len(self._last_retrieved_docs)}๊ฐ ๋ฌธ์")

            # 3. Generate the answer with the selected system prompt.
            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )

            elapsed_time = time.time() - start_time

            # Append the turn to the conversation history.
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            # Response in the same shape as RAGPipeline.
            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'used_retrieval': used_retrieval,
                'query_type': query_type,
                'search_mode': self.search_mode if used_retrieval else 'direct',
                'routing_info': classification,
                'elapsed_time': elapsed_time,
                'usage': self._estimate_usage(query, answer)
            }

        except Exception as e:
            logger.error(f"โ ๋ต๋ณ ์์ฑ ์คํจ: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"๋ต๋ณ ์์ฑ ์คํจ: {str(e)}") from e

    def chat(self, query: str) -> str:
        """Simple conversational interface: return only the answer text."""
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the conversation history."""
        self.chat_history = []
        logger.info("๐๏ธ ๋ํ ํ์คํ ๋ฆฌ๊ฐ ์ด๊ธฐํ๋์์ต๋๋ค.")

    def get_history(self) -> List[Dict]:
        """Return a copy of the conversation history."""
        return self.chat_history.copy()

    def set_search_config(
        self,
        search_mode: str = None,
        top_k: int = None,
        alpha: float = None
    ):
        """Update retrieval settings; omitted parameters are left unchanged."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha

        logger.info(
            f"๐ง ๊ฒ์ ์ค์  ๋ณ๊ฒฝ: mode={self.search_mode}, "
            f"top_k={self.top_k}, alpha={self.alpha}"
        )
src/generator/generator_gguf_base.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llama_cpp import Llama
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from src.utils.config import RAGConfig
|
| 8 |
+
from src.router.query_router import QueryRouter
|
| 9 |
+
from src.prompts.dynamic_prompts import PromptManager
|
| 10 |
+
|
| 11 |
+
# ๋ก๊น
์ค์
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GGUFGenerator:
    """GGUF-format Llama-3 text generator backed by llama.cpp.

    Loads a GGUF model either from the Hugging Face Hub or from a local
    file and answers RFP (request-for-proposal) related questions.

    NOTE(review): the original Korean comments and log strings were
    mojibake-garbled in extraction; logs are reproduced in English and the
    Korean strings that feed the model prompt are reconstructed — confirm
    against the original repository.
    """

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config=None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: str = "당신은 RFP(제안요청서) 분석 및 요약 전문가입니다."
    ):
        """Store generation settings; the heavy model load is deferred to load_model()."""
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Loaded lazily by load_model().
        self.model = None

        # Fixed: original used an f-string with no placeholders.
        logger.info("GGUFGenerator initialized (base model)")

    def load_model(self) -> None:
        """Load the base GGUF model.

        Uses BASE_MODEL_HUB_REPO / BASE_MODEL_HUB_FILENAME from the config
        when USE_MODEL_HUB is enabled, otherwise falls back to the local
        ``model_path`` given at construction time.

        Raises:
            FileNotFoundError: local mode and the file does not exist.
            RuntimeError: any other failure while loading the model.
        """
        # Guard against double-loading.
        if self.model is not None:
            logger.info("Model already loaded; skipping.")
            return

        try:
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)
            base_model_repo = None  # defined in both branches for the summary log below

            if use_model_hub:
                # === Download from the Hugging Face Hub ===
                base_model_repo = getattr(
                    self.config,
                    'BASE_MODEL_HUB_REPO',
                    'beomi/Llama-3-Open-Ko-8B-gguf'
                )
                base_model_filename = getattr(
                    self.config,
                    'BASE_MODEL_HUB_FILENAME',
                    'ggml-model-Q4_K_M.gguf'
                )
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info("Downloading base model: %s / %s", base_model_repo, base_model_filename)

                from huggingface_hub import hf_hub_download

                # NOTE: `local_dir_use_symlinks` is deprecated and ignored by
                # recent huggingface_hub releases, so it has been dropped.
                model_path = hf_hub_download(
                    repo_id=base_model_repo,
                    filename=base_model_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                )

                logger.info("Base model downloaded: %s", model_path)

            else:
                # === Use a local file ===
                model_path = self.model_path

                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"Local model file not found: {model_path}\n"
                        f"Set USE_MODEL_HUB=true or provide the model file."
                    )

                logger.info("Using local base model: %s", model_path)

            # === Common: load the model ===
            logger.info("Loading base GGUF model...")
            logger.info("  GPU layers: %s", self.n_gpu_layers)
            logger.info("  context size: %s", self.n_ctx)

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,
            )

            # Verify the context size llama.cpp actually applied.
            actual_n_ctx = self.model.n_ctx()
            logger.info("Base GGUF model loaded!")
            logger.info("  - model: %s", base_model_repo if use_model_hub else 'local')
            logger.info("  - requested n_ctx: %s", self.n_ctx)
            logger.info("  - actual n_ctx: %s", actual_n_ctx)

            if actual_n_ctx < self.n_ctx:
                logger.warning("n_ctx smaller than requested: %s < %s", actual_n_ctx, self.n_ctx)

        except FileNotFoundError as e:
            logger.error("Model file not found: %s", e)
            raise
        except Exception as e:
            logger.error("Model load failed: %s", e)
            raise RuntimeError(f"Error while loading model: {e}") from e

    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """Build the simple Korean Alpaca-style prompt used by the GGUF model."""
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Include retrieved context when available.
        if context is not None:
            user_message = f"참고 문서:\n{context}\n\n질문: {question}"
        else:
            user_message = question

        # Section headers must match the stop strings used in generate().
        formatted_prompt = f"""### 시스템
{system_prompt}

### 사용자
{user_message}

### 답변
"""

        return formatted_prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Run llama.cpp completion on a fully formatted prompt.

        Raises:
            RuntimeError: if the model is not loaded or generation fails.
        """
        if self.model is None:
            raise RuntimeError(
                "Model not loaded. Call load_model() first."
            )

        # Per-call overrides fall back to instance defaults.
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p

        try:
            logger.info("Generation start (max_tokens=%s, temp=%s)", max_new_tokens, temperature)
            start_time = time.time()

            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,
                # Aggressive stop list: section markers plus Korean question
                # endings, to keep the base model from rambling.
                # (Korean strings reconstructed from garbled source — verify.)
                stop=[
                    "###", "\n\n###",
                    "### 사용자", "\n사용자:",
                    "</s>",
                    "한국어 답변", "한국어로 답변", "지침:",
                    "문장", "(문장",
                    "\n\n",
                    "?",
                    "요?", "까?", "나요?", "습니까?"
                ],
            )

            elapsed = time.time() - start_time
            logger.info("Generation done: %.2fs", elapsed)

            response = output['choices'][0]['text'].strip()

            logger.info("Response length: %d chars", len(response))
            return response

        except Exception as e:
            logger.error("Error during generation: %s", e)
            raise RuntimeError(f"Text generation failed: {e}") from e

    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt=None,
        **kwargs
    ) -> str:
        """Format the prompt for a question (optionally with context) and generate."""
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )

        response = self.generate(prompt, **kwargs)

        return response
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class GGUFBaseRAGPipeline:
    """Base-model + RAG pipeline.

    - Uses the base GGUF model (e.g. beomi/Llama-3-Open-Ko-8B)
    - Retrieval (RAG) is kept
    - Mirrors the public behavior of the fine-tuned pipeline

    NOTE(review): original Korean comments/log strings were mojibake-garbled
    in extraction and are reproduced in English here; prompt/user-facing
    Korean strings are reconstructed — verify against the original repo.
    """

    def __init__(
        self,
        config=None,
        model: str = None,
        top_k: int = None,
        n_gpu_layers: int = None,
        n_ctx: int = None,
        n_threads: int = None,
        max_new_tokens: int = None,
        temperature: float = None,
        top_p: float = None,
        search_mode: str = None,
        alpha: float = None
    ):
        """Initialize retriever, generator (model loaded eagerly) and router."""
        self.config = config or RAGConfig()

        # Retrieval settings: explicit arguments win over config defaults.
        self.top_k = top_k or getattr(self.config, 'DEFAULT_TOP_K', 10)
        self.search_mode = search_mode or getattr(self.config, 'DEFAULT_SEARCH_MODE', 'hybrid_rerank')
        self.alpha = alpha if alpha is not None else getattr(self.config, 'DEFAULT_ALPHA', 0.5)

        # Retriever (imported lazily to avoid a hard dependency at import time).
        logger.info("Initializing RAGRetriever...")
        from src.retriever.retriever import RAGRetriever
        self.retriever = RAGRetriever(config=self.config)

        # GGUF runtime settings.
        gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
        gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
        gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
        gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
        gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
        gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)

        # Local path is only a fallback; the model is normally Hub-downloaded.
        gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')

        # System prompt (Korean; reconstructed from garbled source).
        system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '당신은 한국 공공기관 사업 제안서 분석 전문가입니다.')

        logger.info("Initializing GGUFGenerator... (base model)")
        logger.info("  GPU layers: %s", gguf_n_gpu_layers)
        logger.info("  context size: %s", gguf_n_ctx)

        self.generator = GGUFGenerator(
            model_path=gguf_model_path,
            n_gpu_layers=gguf_n_gpu_layers,
            n_ctx=gguf_n_ctx,
            n_threads=gguf_n_threads,
            config=self.config,
            max_new_tokens=gguf_max_new_tokens,
            temperature=gguf_temperature,
            top_p=gguf_top_p,
            system_prompt=system_prompt
        )

        # Eager load so the first request is not stalled by model loading.
        logger.info("Loading base GGUF model...")
        self.generator.load_model()

        # Query router decides whether retrieval is needed.
        self.router = QueryRouter()

        # Conversation history.
        self.chat_history: List[Dict] = []

        # Documents returned by the most recent retrieval.
        self._last_retrieved_docs = []

        logger.info("GGUFBaseRAGPipeline initialized")
        logger.info("  - search mode: %s", self.search_mode)
        logger.info("  - default top_k: %s", self.top_k)

    def _retrieve_and_format(self, query: str) -> str:
        """Retrieve documents for *query* per the search mode and format them."""
        if self.search_mode == "embedding":
            docs = self.retriever.search(query, top_k=self.top_k)
        elif self.search_mode == "embedding_rerank":
            docs = self.retriever.search_with_rerank(query, top_k=self.top_k)
        elif self.search_mode == "hybrid":
            docs = self.retriever.hybrid_search(
                query, top_k=self.top_k, alpha=self.alpha
            )
        elif self.search_mode == "hybrid_rerank":
            docs = self.retriever.hybrid_search_with_rerank(
                query, top_k=self.top_k, alpha=self.alpha
            )
        else:
            # Unknown mode: fall back to plain embedding search.
            docs = self.retriever.search(query, top_k=self.top_k)

        # Remember for _format_sources().
        self._last_retrieved_docs = docs

        return self._format_context(docs)

    def _format_context(self, retrieved_docs: list) -> str:
        """Join retrieved documents into one context block, capped at ~8000 chars."""
        if not retrieved_docs:
            return "관련 문서를 찾을 수 없습니다."

        context_parts = []
        max_context_chars = 8000

        current_length = 0
        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"[문서 {i}]\n{doc['content']}\n"
            doc_length = len(doc_text)

            # Stop before overflowing the context budget.
            if current_length + doc_length > max_context_chars:
                logger.warning("Context length limit hit: using only %d documents", i - 1)
                break

            context_parts.append(doc_text)
            current_length += doc_length

        return "\n".join(context_parts)

    def _format_sources(self, retrieved_docs: list) -> list:
        """Convert retrieved documents into the `sources` structure returned to callers."""
        sources = []
        for doc in retrieved_docs:
            source_info = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'filename': doc.get('filename', 'N/A'),
                'organization': doc.get('organization', 'N/A')
            }

            # Pick the most specific score available.
            if 'rerank_score' in doc:
                source_info['score'] = doc['rerank_score']
                source_info['score_type'] = 'rerank'
            elif 'hybrid_score' in doc:
                source_info['score'] = doc['hybrid_score']
                source_info['score_type'] = 'hybrid'
            elif 'relevance_score' in doc:
                source_info['score'] = doc['relevance_score']
                source_info['score_type'] = 'embedding'
            else:
                source_info['score'] = 0
                source_info['score_type'] = 'unknown'

            sources.append(source_info)

        return sources

    def _estimate_usage(self, query: str, answer: str) -> dict:
        """Very rough token estimate: ~2 tokens per whitespace-separated word."""
        prompt_tokens = len(query.split()) * 2
        completion_tokens = len(answer.split()) * 2

        return {
            'total_tokens': prompt_tokens + completion_tokens,
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens
        }

    def generate_answer(
        self,
        query: str,
        top_k: int = None,
        search_mode: str = None,
        alpha: float = None
    ) -> dict:
        """Answer *query* with the base model, retrieving context when needed.

        Returns:
            dict with answer, sources, used_retrieval, query_type,
            search_mode, routing_info, elapsed_time, usage.

        NOTE: per-call overrides (top_k / search_mode / alpha) persist on the
        instance — kept as-is for backward compatibility.
        """
        try:
            start_time = time.time()

            if top_k is not None:
                self.top_k = top_k
            if search_mode is not None:
                self.search_mode = search_mode
            if alpha is not None:
                self.alpha = alpha

            # Route the query to decide whether retrieval is required.
            classification = self.router.classify(query)
            query_type = classification['type']

            logger.info("Classified as %s (confidence: %.2f)", query_type, classification['confidence'])

            if query_type in ['greeting', 'thanks', 'out_of_scope']:
                # Skip retrieval entirely.
                context = None
                used_retrieval = False
                self._last_retrieved_docs = []

                system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
                logger.info("RAG skipped: %s", query_type)
            else:
                # BUGFIX: the original handled only 'document' here and left
                # context/system_prompt/used_retrieval unassigned for any
                # other router type, raising NameError. Unknown types now
                # fall back to the RAG path.
                context = self._retrieve_and_format(query)
                used_retrieval = True

                system_prompt = PromptManager.get_prompt('document', model_type="gguf")
                logger.info("RAG performed: %d documents", len(self._last_retrieved_docs))

            answer = self.generator.chat(
                question=query,
                context=context,
                system_prompt=system_prompt
            )

            elapsed_time = time.time() - start_time

            # Record the exchange.
            self.chat_history.append({"role": "user", "content": query})
            self.chat_history.append({"role": "assistant", "content": answer})

            return {
                'answer': answer,
                'sources': self._format_sources(self._last_retrieved_docs),
                'used_retrieval': used_retrieval,
                'query_type': query_type,
                'search_mode': self.search_mode if used_retrieval else 'direct',
                'routing_info': classification,
                'elapsed_time': elapsed_time,
                'usage': self._estimate_usage(query, answer)
            }

        except Exception as e:
            logger.error("Answer generation failed: %s", e)
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"답변 생성 실패: {str(e)}") from e

    def chat(self, query: str) -> str:
        """Convenience wrapper: return only the answer text."""
        result = self.generate_answer(query)
        return result['answer']

    def clear_history(self):
        """Reset the conversation history."""
        self.chat_history = []
        logger.info("Conversation history cleared.")

    def get_history(self) -> List[Dict]:
        """Return a copy of the conversation history."""
        return self.chat_history.copy()

    def set_search_config(
        self,
        search_mode: str = None,
        top_k: int = None,
        alpha: float = None
    ):
        """Update retrieval settings in place."""
        if search_mode is not None:
            self.search_mode = search_mode
        if top_k is not None:
            self.top_k = top_k
        if alpha is not None:
            self.alpha = alpha

        logger.info(
            "Search config updated: mode=%s, top_k=%s, alpha=%s",
            self.search_mode, self.top_k, self.alpha
        )
|
src/generator/generator_gguf_no_rag.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llama_cpp import Llama
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from src.utils.config import RAGConfig
|
| 8 |
+
from src.router.query_router import QueryRouter
|
| 9 |
+
from src.prompts.dynamic_prompts import PromptManager
|
| 10 |
+
|
| 11 |
+
# ๋ก๊น
์ค์
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GGUFGenerator:
    """GGUF-format Llama-3 text generator backed by llama.cpp (no-RAG variant).

    Loads a GGUF model either from the Hugging Face Hub or from a local
    file and answers RFP (request-for-proposal) related questions.

    NOTE(review): the original Korean comments and log strings were
    mojibake-garbled in extraction; logs are reproduced in English and the
    Korean strings that feed the model prompt are reconstructed — confirm
    against the original repository.
    """

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        n_threads: int = 8,
        config=None,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        system_prompt: str = "당신은 RFP(제안요청서) 분석 및 요약 전문가입니다."
    ):
        """Store generation settings; the heavy model load is deferred to load_model()."""
        self.config = config or RAGConfig()
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.system_prompt = system_prompt

        # Loaded lazily by load_model().
        self.model = None

        # Fixed: original used an f-string with no placeholders.
        logger.info("GGUFGenerator initialized")

    def load_model(self) -> None:
        """Load the GGUF model (Hub download when USE_MODEL_HUB, else local file).

        Raises:
            FileNotFoundError: local mode and the file does not exist.
            RuntimeError: any other failure while loading the model.
        """
        # Guard against double-loading.
        if self.model is not None:
            logger.info("Model already loaded; skipping.")
            return

        try:
            use_model_hub = getattr(self.config, 'USE_MODEL_HUB', True)

            if use_model_hub:
                # === Download from the Hugging Face Hub ===
                model_hub_repo = getattr(self.config, 'MODEL_HUB_REPO', 'beomi/Llama-3-Open-Ko-8B-gguf')
                model_hub_filename = getattr(self.config, 'MODEL_HUB_FILENAME', 'ggml-model-Q4_K_M.gguf')
                model_cache_dir = getattr(self.config, 'MODEL_CACHE_DIR', '.cache/models')

                logger.info("Downloading from Model Hub: %s", model_hub_repo)

                from huggingface_hub import hf_hub_download

                # NOTE: `local_dir_use_symlinks` is deprecated and ignored by
                # recent huggingface_hub releases, so it has been dropped.
                model_path = hf_hub_download(
                    repo_id=model_hub_repo,
                    filename=model_hub_filename,
                    cache_dir=model_cache_dir,
                    local_dir=model_cache_dir,
                )

                logger.info("Download complete: %s", model_path)

            else:
                # === Use a local file ===
                model_path = self.model_path

                if not os.path.exists(model_path):
                    raise FileNotFoundError(
                        f"Local model file not found: {model_path}\n"
                        f"Set USE_MODEL_HUB=true or provide the model file."
                    )

                logger.info("Using local model: %s", model_path)

            # === Common: load the model ===
            logger.info("Loading GGUF model...")
            logger.info("  GPU layers: %s", self.n_gpu_layers)
            logger.info("  context size: %s", self.n_ctx)

            self.model = Llama(
                model_path=model_path,
                n_gpu_layers=self.n_gpu_layers,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
                verbose=True,
            )

            # Verify the context size llama.cpp actually applied.
            actual_n_ctx = self.model.n_ctx()
            logger.info("GGUF model loaded!")
            logger.info("  - requested n_ctx: %s", self.n_ctx)
            logger.info("  - actual n_ctx: %s", actual_n_ctx)

            if actual_n_ctx < self.n_ctx:
                logger.warning("n_ctx smaller than requested: %s < %s", actual_n_ctx, self.n_ctx)
                logger.warning("  Possibly out of memory; try lowering n_gpu_layers.")

        except FileNotFoundError as e:
            logger.error("Model file not found: %s", e)
            raise
        except Exception as e:
            logger.error("Model load failed: %s", e)
            raise RuntimeError(f"Error while loading model: {e}") from e

    def format_prompt(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """Build the simple Korean Alpaca-style prompt used by the GGUF model."""
        # Fall back to the instance-level system prompt.
        if system_prompt is None:
            system_prompt = self.system_prompt

        # Include context when supplied (unused in the no-RAG pipeline).
        if context is not None:
            user_message = f"참고 문서:\n{context}\n\n질문: {question}"
        else:
            user_message = question

        # Section headers must match the stop strings used in generate().
        formatted_prompt = f"""### 시스템
{system_prompt}

### 사용자
{user_message}

### 답변
"""

        return formatted_prompt

    def generate(
        self,
        prompt: str,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Run llama.cpp completion on a fully formatted prompt.

        Raises:
            RuntimeError: if the model is not loaded or generation fails.
        """
        if self.model is None:
            raise RuntimeError(
                "Model not loaded. Call load_model() first."
            )

        # Per-call overrides fall back to instance defaults.
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p

        try:
            logger.info("Generation start (max_tokens=%s, temp=%s)", max_new_tokens, temperature)
            start_time = time.time()

            output = self.model(
                prompt,
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                echo=False,
                # Aggressive stop list: section markers plus Korean question
                # endings, to keep the model from rambling.
                # (Korean strings reconstructed from garbled source — verify.)
                stop=[
                    "###", "\n\n###",
                    "### 사용자", "\n사용자:",
                    "</s>",
                    "한국어 답변", "한국어로 답변", "지침:",
                    "문장", "(문장",
                    "\n\n",
                    "?",
                    "요?", "까?", "나요?", "습니까?"
                ],
            )

            elapsed = time.time() - start_time
            logger.info("Generation done: %.2fs", elapsed)

            # Extract the completion text.
            response = output['choices'][0]['text'].strip()

            logger.info("Response length: %d chars", len(response))
            return response

        except Exception as e:
            logger.error("Error during generation: %s", e)
            raise RuntimeError(f"Text generation failed: {e}") from e

    def chat(
        self,
        question: str,
        context: Optional[str] = None,
        system_prompt=None,
        **kwargs
    ) -> str:
        """Format the prompt for a question (optionally with context) and generate."""
        prompt = self.format_prompt(
            question=question,
            context=context,
            system_prompt=system_prompt
        )

        response = self.generate(prompt, **kwargs)

        return response
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class GGUFNoRAGPipeline:
|
| 236 |
+
"""
|
| 237 |
+
QLoRA ๋ชจ๋ธ ๋จ๋
ํ์ดํ๋ผ์ธ (RAG ์ ๊ฑฐ)
|
| 238 |
+
|
| 239 |
+
โ
Retriever ์์ ์ ๊ฑฐ
|
| 240 |
+
โ
Router๋ง ์ ์ง (greeting/thanks ์ฒ๋ฆฌ์ฉ)
|
| 241 |
+
โ
์์ ๋ชจ๋ธ ์ฑ๋ฅ๋ง ์ธก์
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(
|
| 245 |
+
self,
|
| 246 |
+
config=None,
|
| 247 |
+
model: str = None,
|
| 248 |
+
top_k: int = None,
|
| 249 |
+
n_gpu_layers: int = None,
|
| 250 |
+
n_ctx: int = None,
|
| 251 |
+
n_threads: int = None,
|
| 252 |
+
max_new_tokens: int = None,
|
| 253 |
+
temperature: float = None,
|
| 254 |
+
top_p: float = None
|
| 255 |
+
):
|
| 256 |
+
"""์ด๊ธฐํ"""
|
| 257 |
+
self.config = config or RAGConfig()
|
| 258 |
+
|
| 259 |
+
# GGUF ์ค์
|
| 260 |
+
gguf_n_gpu_layers = n_gpu_layers if n_gpu_layers is not None else getattr(self.config, 'GGUF_N_GPU_LAYERS', 35)
|
| 261 |
+
gguf_n_ctx = n_ctx if n_ctx is not None else getattr(self.config, 'GGUF_N_CTX', 2048)
|
| 262 |
+
gguf_n_threads = n_threads if n_threads is not None else getattr(self.config, 'GGUF_N_THREADS', 4)
|
| 263 |
+
gguf_max_new_tokens = max_new_tokens if max_new_tokens is not None else getattr(self.config, 'GGUF_MAX_NEW_TOKENS', 512)
|
| 264 |
+
gguf_temperature = temperature if temperature is not None else getattr(self.config, 'GGUF_TEMPERATURE', 0.7)
|
| 265 |
+
gguf_top_p = top_p if top_p is not None else getattr(self.config, 'GGUF_TOP_P', 0.9)
|
| 266 |
+
|
| 267 |
+
# ๋ชจ๋ธ ๊ฒฝ๋ก
|
| 268 |
+
gguf_model_path = getattr(self.config, 'GGUF_MODEL_PATH', '.cache/models/llama-3-ko-8b.gguf')
|
| 269 |
+
|
| 270 |
+
# ์์คํ
ํ๋กฌํํธ
|
| 271 |
+
system_prompt = getattr(self.config, 'SYSTEM_PROMPT', '๋น์ ์ ํ๊ตญ ๊ณต๊ณต๊ธฐ๊ด ์ฌ์
์ ์์ ๋ถ์ ์ ๋ฌธ๊ฐ์
๋๋ค.')
|
| 272 |
+
|
| 273 |
+
# GGUFGenerator ์ด๊ธฐํ
|
| 274 |
+
logger.info("GGUFGenerator ์ด๊ธฐํ ์ค... (RAG ์์)")
|
| 275 |
+
logger.info(f" GPU ๋ ์ด์ด: {gguf_n_gpu_layers}")
|
| 276 |
+
logger.info(f" ์ปจํ
์คํธ: {gguf_n_ctx}")
|
| 277 |
+
logger.info(f" ์ค๋ ๋: {gguf_n_threads}")
|
| 278 |
+
|
| 279 |
+
self.generator = GGUFGenerator(
|
| 280 |
+
model_path=gguf_model_path,
|
| 281 |
+
n_gpu_layers=gguf_n_gpu_layers,
|
| 282 |
+
n_ctx=gguf_n_ctx,
|
| 283 |
+
n_threads=gguf_n_threads,
|
| 284 |
+
config=self.config,
|
| 285 |
+
max_new_tokens=gguf_max_new_tokens,
|
| 286 |
+
temperature=gguf_temperature,
|
| 287 |
+
top_p=gguf_top_p,
|
| 288 |
+
system_prompt=system_prompt
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# ๋ชจ๋ธ ๋ก๋
|
| 292 |
+
logger.info("GGUF ๋ชจ๋ธ ๋ก๋ ์ค...")
|
| 293 |
+
self.generator.load_model()
|
| 294 |
+
|
| 295 |
+
# โ
Retriever ์์ (์์ ์ ๊ฑฐ)
|
| 296 |
+
self.retriever = None
|
| 297 |
+
|
| 298 |
+
# Router (greeting/thanks ์ฒ๋ฆฌ์ฉ)
|
| 299 |
+
self.router = QueryRouter()
|
| 300 |
+
|
| 301 |
+
# ๋ํ ํ์คํ ๋ฆฌ
|
| 302 |
+
self.chat_history: List[Dict] = []
|
| 303 |
+
|
| 304 |
+
logger.info("โ
GGUFNoRAGPipeline ์ด๊ธฐํ ์๋ฃ (RAG ์ ๊ฑฐ)")
|
| 305 |
+
logger.info(" - Retriever: โ ์์")
|
| 306 |
+
logger.info(" - Router: โ
์์ (greeting/thanks์ฉ)")
|
| 307 |
+
|
| 308 |
+
def _estimate_usage(self, query: str, answer: str) -> dict:
|
| 309 |
+
"""ํ ํฐ ์ฌ์ฉ๋ ์ถ์ """
|
| 310 |
+
prompt_tokens = len(query.split()) * 2
|
| 311 |
+
completion_tokens = len(answer.split()) * 2
|
| 312 |
+
|
| 313 |
+
return {
|
| 314 |
+
'total_tokens': prompt_tokens + completion_tokens,
|
| 315 |
+
'prompt_tokens': prompt_tokens,
|
| 316 |
+
'completion_tokens': completion_tokens
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
def generate_answer(
    self,
    query: str,
    top_k: int = None,
    search_mode: str = None,
    alpha: float = None
) -> dict:
    """Generate an answer WITHOUT retrieval (no-RAG pipeline).

    Args:
        query: User question.
        top_k: Unused; kept for interface compatibility with the RAG pipeline.
        search_mode: Unused; kept for interface compatibility.
        alpha: Unused; kept for interface compatibility.

    Returns:
        dict with keys: answer, sources (always []), used_retrieval (always
        False), query_type, search_mode (always 'none'), routing_info,
        elapsed_time, usage.

    Raises:
        RuntimeError: wraps any failure during classification or generation.
    """
    try:
        start_time = time.time()

        # Classify the query so greetings/thanks get a dedicated prompt.
        classification = self.router.classify(query)
        query_type = classification['type']

        logger.info(f"๐ ๋ถ๋ฅ: {query_type} (์ ๋ขฐ๋: {classification['confidence']:.2f})")

        # Select the system prompt by query type; everything that is not a
        # greeting/thanks/out-of-scope query falls back to the document prompt.
        if query_type in ['greeting', 'thanks', 'out_of_scope']:
            system_prompt = PromptManager.get_prompt(query_type, model_type="gguf")
        else:
            system_prompt = PromptManager.get_prompt('document', model_type="gguf")

        # Always generate with no retrieval context (context=None) — this is
        # the no-RAG baseline pipeline.
        answer = self.generator.chat(
            question=query,
            context=None,
            system_prompt=system_prompt
        )

        elapsed_time = time.time() - start_time

        # Record the turn in the in-memory conversation history.
        self.chat_history.append({"role": "user", "content": query})
        self.chat_history.append({"role": "assistant", "content": answer})

        # Result mirrors the RAG pipeline's schema with retrieval fields
        # stubbed out so callers can treat both pipelines uniformly.
        return {
            'answer': answer,
            'sources': [],
            'used_retrieval': False,
            'query_type': query_type,
            'search_mode': 'none',
            'routing_info': classification,
            'elapsed_time': elapsed_time,
            'usage': self._estimate_usage(query, answer)
        }

    except Exception as e:
        logger.error(f"โ ๋ต๋ณ ์์ฑ ์คํจ: {e}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(f"๋ต๋ณ ์์ฑ ์คํจ: {str(e)}") from e
|
| 383 |
+
|
| 384 |
+
def chat(self, query: str) -> str:
    """Convenience wrapper: return only the answer text for *query*."""
    return self.generate_answer(query)['answer']
|
| 388 |
+
|
| 389 |
+
def clear_history(self):
    """Reset the conversation history to a fresh empty list."""
    # Rebind (rather than mutate) so previously handed-out copies are unaffected.
    self.chat_history = []
    logger.info("๐๏ธ ๋ํ ํ์คํ ๋ฆฌ๊ฐ ์ด๊ธฐํ๋์์ต๋๋ค.")
|
| 393 |
+
|
| 394 |
+
def get_history(self) -> List[Dict]:
    """Return a shallow copy of the conversation history."""
    return list(self.chat_history)
|
src/retriever/main.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from RAG_pipeline_v1.rag_config import RAGConfig
|
| 3 |
+
from RAG_pipeline_v1.rag_data_processing import RAGVectorDBPipeline
|
| 4 |
+
from RAG_pipeline_v1.rag_pipeline import RAGPipeline
|
| 5 |
+
from RAG_pipeline_v1.rag_evaluator import RAGEvaluator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
    """Entry point: configure the environment, run smoke-test queries, evaluate."""
    line = "=" * 60

    # ===== Environment setup =====
    print(line)
    print("RAG ์์คํ ์ด๊ธฐํ")
    print(line)

    # NOTE(review): reads OPENAI_API_KEY as a class attribute before any
    # instance exists — confirm RAG_pipeline_v1.rag_config defines it at
    # class level.
    os.environ["OPENAI_API_KEY"] = RAGConfig.OPENAI_API_KEY

    config = RAGConfig()
    config.validate()
    print(config)

    # ===== 1. Vector DB build (first run only) =====
    # Uncomment to (re)build the vector store:
    # db_pipeline = RAGVectorDBPipeline(config)
    # vectorstore = db_pipeline.build()
    # db_pipeline.test_search()

    # ===== 2. RAG pipeline initialization =====
    print("\n" + line)
    print("RAG ํ์ดํ๋ผ์ธ ์ด๊ธฐํ")
    print(line)

    rag = RAGPipeline(config=config)

    # ===== 3. Smoke-test queries =====
    print("\n" + line)
    print("ํ ์คํธ ์ฟผ๋ฆฌ")
    print(line)

    sample_questions = [
        "ํ์๋ํ๊ต์ ํน์ฑํ ๊ต์กํ๊ฒฝ ๊ตฌ์ถ ์ฌ์ ์ ๋ฌด์์ธ๊ฐ์?",
        "์ฌ๋ ์์ ๊ด๋ฆฌ ์์คํ ๊ตฌ์ถ ์ฌ์ ์ ์ด๋ค ๊ฒ์ด ์๋์?",
    ]
    for question in sample_questions:
        rag.print_result(rag.generate_answer(question))
        print("\n")

    # ===== 4. Evaluation =====
    print("\n" + line)
    print("์์คํ ํ๊ฐ")
    print(line)

    evaluator = RAGEvaluator(rag)
    eval_results = evaluator.evaluate()

    print("\n" + line)
    print("โ ๋ชจ๋ ์์ ์๋ฃ")
    print(line)


if __name__ == "__main__":
    main()
|
src/retriever/retriever.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_chroma import Chroma
|
| 2 |
+
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 3 |
+
from langsmith import traceable
|
| 4 |
+
import time
|
| 5 |
+
import os
|
| 6 |
+
from rank_bm25 import BM25Okapi
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sentence_transformers import CrossEncoder
|
| 9 |
+
|
| 10 |
+
from src.utils.config import RAGConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RAGRetriever:
    """RAG retrieval system (Hybrid Search + Re-ranker).

    Combines BM25 keyword scores with OpenAI-embedding similarity over a
    persisted Chroma collection, and optionally re-ranks candidates with a
    CrossEncoder model.
    """

    def __init__(self, config: RAGConfig = None):
        """Load the vector store and build the BM25 index and re-ranker.

        Args:
            config: Optional RAGConfig; a default instance is created if omitted.
        """
        self.config = config or RAGConfig()
        self.vectorstore = None
        self.embeddings = None

        self._initialize_embeddings()
        self._create_vectorstore()
        self._initialize_bm25()
        self._initialize_reranker()

    def _initialize_embeddings(self):
        """Initialize the OpenAI embedding model."""
        os.environ["OPENAI_API_KEY"] = self.config.OPENAI_API_KEY

        self.embeddings = OpenAIEmbeddings(
            model=self.config.EMBEDDING_MODEL_NAME
        )

    def _create_vectorstore(self):
        """Load the already-persisted Chroma vector store."""
        self.vectorstore = Chroma(
            embedding_function=self.embeddings,
            persist_directory=self.config.DB_DIRECTORY,
            collection_name=self.config.COLLECTION_NAME
        )

    def _initialize_bm25(self):
        """Build the in-memory BM25 index over all stored documents."""
        all_docs = self.vectorstore.get()

        self.doc_texts = all_docs['documents']
        self.doc_ids = all_docs['ids']
        self.doc_metadatas = all_docs['metadatas']

        # Reverse lookup from document text to its id (used after embedding search).
        self.content_to_id = {text: doc_id for text, doc_id in zip(self.doc_texts, self.doc_ids)}
        # FIX: O(1) id -> index lookup. hybrid_search previously called
        # list.index() per returned document, which is O(n) per lookup.
        self.id_to_index = {doc_id: i for i, doc_id in enumerate(self.doc_ids)}

        # Naive whitespace tokenization; assumes texts are pre-normalized.
        tokenized_docs = [doc.split() for doc in self.doc_texts]
        self.bm25 = BM25Okapi(tokenized_docs)

        print(f"โ BM25 ์ธ๋ฑ์ค ์์ฑ ์๋ฃ: {len(self.doc_texts)}๊ฐ ๋ฌธ์")

    def _initialize_reranker(self):
        """Initialize the CrossEncoder re-ranker."""
        self.reranker = CrossEncoder('BAAI/bge-reranker-base')
        print("โ Re-ranker ์ด๊ธฐํ ์๋ฃ (bge-reranker-base)")

    @staticmethod
    def _min_max_normalize(scores):
        """Normalize scores into [0, 1].

        All-equal input maps to 0.5 everywhere; empty input returns an
        empty float array (FIX: previously raised on .min() of empty).
        """
        scores = np.array(scores)
        if scores.size == 0:
            return scores.astype(float)

        min_score = scores.min()
        max_score = scores.max()

        if max_score == min_score:
            return np.full_like(scores, 0.5, dtype=float)

        return (scores - min_score) / (max_score - min_score)

    def _find_doc_id_by_content(self, content):
        """Look up a document id by its exact content; None if unknown."""
        return self.content_to_id.get(content, None)

    def _rerank(self, query, documents, top_k):
        """Re-rank candidate documents with the CrossEncoder.

        Args:
            query: Search query.
            documents: Result dicts produced by hybrid_search/search.
            top_k: Number of documents to keep after re-ranking.

        Returns:
            Top-k documents sorted by 'rerank_score' (descending).
        """
        if len(documents) == 0:
            return []

        # 1. Build (query, document) pairs for the cross-encoder.
        pairs = [[query, doc['content']] for doc in documents]

        # 2. Score each pair.
        scores = self.reranker.predict(pairs)

        # 3. Attach scores to the documents (mutates the input dicts).
        for i, doc in enumerate(documents):
            doc['rerank_score'] = float(scores[i])

        # 4. Sort by score and return the best k.
        sorted_docs = sorted(documents,
                             key=lambda x: x['rerank_score'],
                             reverse=True)

        return sorted_docs[:top_k]

    @traceable(
        name="RAG_Hybrid_Search",
        metadata={"component": "retriever", "version": "2.0"}
    )
    def hybrid_search(self, query, top_k=None, alpha=0.5):
        """Hybrid search: weighted blend of normalized BM25 and embedding scores.

        Args:
            query: Search query string.
            top_k: Number of documents to return (defaults to config.DEFAULT_TOP_K).
            alpha: Embedding weight in [0, 1]; 0 = pure BM25, 1 = pure embedding.

        Returns:
            List of result dicts sorted by hybrid score (descending).
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        # 1. BM25 scores over the whole corpus (whitespace tokenization).
        tokenized_query = query.split()
        bm25_scores = self.bm25.get_scores(tokenized_query)
        bm25_normalized = self._min_max_normalize(bm25_scores)

        # 2. Embedding search (over-fetch 3x candidates for blending).
        embedding_results = self.vectorstore.similarity_search_with_score(
            query, k=min(top_k * 3, len(self.doc_texts))
        )

        # 3. Convert distances to similarities and normalize to [0, 1].
        embedding_scores_raw = {}
        for doc, distance in embedding_results:
            doc_id = self._find_doc_id_by_content(doc.page_content)
            if doc_id:
                embedding_scores_raw[doc_id] = 1 / (1 + distance)

        if embedding_scores_raw:
            embed_values = np.array(list(embedding_scores_raw.values()))
            embed_normalized = self._min_max_normalize(embed_values)
            embedding_scores = dict(zip(embedding_scores_raw.keys(), embed_normalized))
        else:
            embedding_scores = {}

        # 4. Weighted blend of the two normalized score sets.
        hybrid_scores = {}
        for i, doc_id in enumerate(self.doc_ids):
            bm25_score = bm25_normalized[i]
            embed_score = embedding_scores.get(doc_id, 0)
            hybrid_scores[doc_id] = (1 - alpha) * bm25_score + alpha * embed_score

        # 5. Sort ids by blended score and keep the top k.
        sorted_ids = sorted(hybrid_scores.keys(),
                            key=lambda x: hybrid_scores[x],
                            reverse=True)
        top_ids = sorted_ids[:top_k]

        # 6. Format results (FIX: O(1) map lookup instead of list.index()).
        formatted_results = []
        for doc_id in top_ids:
            idx = self.id_to_index[doc_id]
            formatted_results.append({
                'content': self.doc_texts[idx],
                'metadata': self.doc_metadatas[idx],
                'hybrid_score': hybrid_scores[doc_id],
                'bm25_score': float(bm25_normalized[idx]),
                'embed_score': embedding_scores.get(doc_id, 0),
                'filename': self.doc_metadatas[idx].get('ํ์ผ๋ช', 'N/A'),
                'organization': self.doc_metadatas[idx].get('๋ฐ์ฃผ ๊ธฐ๊ด', 'N/A')
            })

        end_time = time.time()
        print(f"๐ Hybrid ๊ฒ์ ์๋ฃ: {len(formatted_results)}๊ฐ (alpha={alpha}, {end_time-start_time:.3f}์ด)")
        return formatted_results

    @traceable(
        name="RAG_Hybrid_Search_Rerank",
        metadata={"component": "retriever", "version": "3.0"}
    )
    def hybrid_search_with_rerank(self, query, top_k=None, alpha=0.5, rerank_candidates=None):
        """Hybrid search followed by CrossEncoder re-ranking.

        Args:
            query: Search query.
            top_k: Final number of documents to return.
            alpha: BM25/embedding blend weight.
            rerank_candidates: Candidate pool size for re-ranking (None -> top_k * 3).
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # 1. Over-fetch candidates with hybrid search.
        candidates = self.hybrid_search(query, top_k=rerank_candidates, alpha=alpha)

        # 2. Re-rank the candidates down to top_k.
        if len(candidates) > 0:
            results = self._rerank(query, candidates, top_k)
        else:
            results = []

        end_time = time.time()
        print(f"๐ Re-ranking ์๋ฃ: {len(candidates)}๊ฐ โ {len(results)}๊ฐ ({end_time-start_time:.3f}์ด)")

        return results

    def search_with_mode(self, query, top_k=None, mode="hybrid_rerank", alpha=0.5):
        """Dispatch to a search strategy by mode name.

        Modes: 'embedding', 'bm25', 'hybrid', 'hybrid_rerank'.

        Raises:
            ValueError: for an unknown mode.
        """
        if mode == "embedding":
            return self.search(query, top_k)
        elif mode == "bm25":
            return self.hybrid_search(query, top_k, alpha=0.0)
        elif mode == "hybrid":
            return self.hybrid_search(query, top_k, alpha=alpha)
        elif mode == "hybrid_rerank":
            return self.hybrid_search_with_rerank(query, top_k, alpha)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    @traceable(
        name="RAG_Retriever_Search",
        metadata={"component": "retriever", "version": "1.0"}
    )
    def search(self, query: str, top_k: int = None, filter_metadata: dict = None):
        """Embedding-based similarity search.

        Args:
            query: Search query.
            top_k: Number of documents (defaults to config.DEFAULT_TOP_K).
            filter_metadata: Optional Chroma metadata filter.

        Returns:
            List of result dicts with content, metadata and scores.
        """
        start_time = time.time()
        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if filter_metadata:
            results = self.vectorstore.similarity_search_with_score(
                query, k=top_k, filter=filter_metadata
            )
        else:
            results = self.vectorstore.similarity_search_with_score(
                query, k=top_k
            )

        formatted_results = []
        for doc, score in results:
            formatted_results.append({
                'content': doc.page_content,
                'metadata': doc.metadata,
                'distance': score,
                # NOTE(review): assumes the distance lies in [0, 1]; for other
                # metrics relevance_score can go negative — confirm the
                # collection's distance function.
                'relevance_score': 1 - score,
                'filename': doc.metadata.get('ํ์ผ๋ช', 'N/A'),
                'organization': doc.metadata.get('๋ฐ์ฃผ ๊ธฐ๊ด', 'N/A')
            })

        end_time = time.time()
        print(f"๐ ๊ฒ์ ์๋ฃ: {len(results)}๊ฐ ({end_time-start_time:.3f}์ด)")
        return formatted_results

    def search_with_rerank(self, query, top_k=None, rerank_candidates=None):
        """Embedding search followed by CrossEncoder re-ranking.

        Args:
            query: Search query.
            top_k: Final number of documents to return.
            rerank_candidates: Candidate pool size (None -> top_k * 3).

        Returns:
            Re-ranked document list.
        """
        start_time = time.time()

        if top_k is None:
            top_k = self.config.DEFAULT_TOP_K

        if rerank_candidates is None:
            rerank_candidates = top_k * 3

        # 1. Over-fetch candidates with the embedding search.
        candidates = self.search(query, top_k=rerank_candidates)

        # 2. Re-rank the candidates down to top_k.
        if len(candidates) > 0:
            results = self._rerank(query, candidates, top_k)
        else:
            results = []

        end_time = time.time()
        print(f"๐ Embedding + Re-ranking ์๋ฃ: {len(candidates)}๊ฐ โ {len(results)}๊ฐ ({end_time-start_time:.3f}์ด)")

        return results

    def search_by_organization(self, query: str, organization: str, top_k: int = None):
        """Search restricted to documents from one ordering organization."""
        return self.search(
            query, top_k=top_k, filter_metadata={'๋ฐ์ฃผ ๊ธฐ๊ด': organization}
        )

    def get_retriever(self):
        """Return a LangChain-compatible retriever over the vector store."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": self.config.DEFAULT_TOP_K}
        )
|
src/utils/config.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Config:
    """Unified configuration for the RAG system.

    Covers preprocessing paths, chunking, embeddings, vector DB, retrieval
    defaults, OpenAI LLM settings and local GGUF model settings. Most GGUF
    values can be overridden via environment variables.
    """

    def __init__(self):
        # Load variables from a local .env file, if present.
        load_dotenv()

        # ===== API keys =====
        self.OPENAI_API_KEY = self._get_api_key()

        # ===== Paths =====
        # Preprocessing inputs/outputs
        self.META_CSV_PATH = "./data/data_list.csv"
        self.BASE_FOLDER_PATH = "./data/files/"
        self.OUTPUT_CHUNKS_PATH = "./data/rag_chunks_final.csv"

        # RAG input; the DB directory honors the CHROMA_DB_PATH env var,
        # falling back to a local ./chroma_db.
        self.RAG_INPUT_PATH = "./data/rag_chunks_final.csv"
        self.DB_DIRECTORY = os.getenv("CHROMA_DB_PATH", "./chroma_db")

        # ===== Preprocessing settings =====
        self.CHUNK_SIZE = 1000
        self.CHUNK_OVERLAP = 200
        self.SEPARATORS = ["\n\n", "\n", " ", ""]
        self.MIN_TEXT_LENGTH = 100

        # ===== Embedding settings =====
        self.EMBEDDING_MODEL_NAME = "text-embedding-3-small"
        self.BATCH_SIZE = 50
        self.MAX_TOKENS_PER_BATCH = 250000

        # Chunk validation bounds
        self.MIN_CHUNK_LENGTH = 10
        self.MAX_CHUNK_LENGTH = 10000

        # ===== Vector DB settings =====
        self.COLLECTION_NAME = "rag_documents"

        # ===== Retrieval defaults =====
        self.DEFAULT_TOP_K = 10
        self.DEFAULT_ALPHA = 0.5
        self.DEFAULT_SEARCH_MODE = "hybrid_rerank"

        # ===== LLM settings =====
        self.LLM_MODEL_NAME = "gpt-4o-mini"
        self.DEFAULT_TEMPERATURE = 0.0
        self.DEFAULT_MAX_TOKENS = 1000

        # System prompt (Korean: "You are an RFP analysis and summarization expert.")
        self.SYSTEM_PROMPT = "๋น์ ์ RFP(์ ์์์ฒญ์) ๋ถ์ ๋ฐ ์์ฝ ์ ๋ฌธ๊ฐ์ ๋๋ค."

        # ===== GGUF local model settings =====
        # Whether to pull models from the Hugging Face Hub (env var wins).
        self.USE_MODEL_HUB = os.getenv("USE_MODEL_HUB", "true").lower() == "true"

        # 1. QLoRA fine-tuned model — used by the production service.
        self.MODEL_HUB_REPO = os.getenv(
            "MODEL_HUB_REPO",
            "Dongjin1203/RFP_Documents_chatbot"
        )
        self.MODEL_HUB_FILENAME = os.getenv(
            "MODEL_HUB_FILENAME",
            "Llama-3-Open-Ko-8B.Q4_K_M.gguf"
        )

        # 2. Base model (no PEFT) — used for comparison experiments.
        self.BASE_MODEL_HUB_REPO = os.getenv(
            "BASE_MODEL_HUB_REPO",
            "beomi/Llama-3-Open-Ko-8B-gguf"
        )
        self.BASE_MODEL_HUB_FILENAME = os.getenv(
            "BASE_MODEL_HUB_FILENAME",
            "ggml-model-Q4_K_M.gguf"
        )

        # Shared model cache directory
        self.MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", ".cache/models")

        # Local model path (used when USE_MODEL_HUB is false)
        self.GGUF_MODEL_PATH = os.getenv("GGUF_MODEL_PATH", ".cache/models/Llama-3-Open-Ko-8B.Q4_K_M.gguf")


        # GGUF GPU settings (tuned for a T4 running an 8B model)
        self.GGUF_N_GPU_LAYERS = int(os.getenv("GGUF_N_GPU_LAYERS", "35"))  # layers offloaded to GPU (whole 8B model on a T4)
        self.GGUF_N_CTX = int(os.getenv("GGUF_N_CTX", "2048"))  # context window length
        self.GGUF_N_THREADS = int(os.getenv("GGUF_N_THREADS", "4"))  # CPU threads (keep low when the GPU is used)
        self.GGUF_MAX_NEW_TOKENS = int(os.getenv("GGUF_MAX_NEW_TOKENS", "512"))  # generation cap
        self.GGUF_TEMPERATURE = float(os.getenv("GGUF_TEMPERATURE", "0.7"))  # sampling temperature
        self.GGUF_TOP_P = float(os.getenv("GGUF_TOP_P", "0.9"))  # nucleus sampling

    def _get_api_key(self) -> str:
        """Read OPENAI_API_KEY from the environment; raise with guidance if missing."""
        api_key = os.getenv("OPENAI_API_KEY")

        if not api_key:
            raise ValueError(
                "OPENAI_API_KEY๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค.\n"
                "ํ๋ก์ ํธ ๋ฃจํธ์ .env ํ์ผ์ ๋ง๋ค๊ณ OPENAI_API_KEY=your-key ๋ฅผ ์ถ๊ฐํ์ธ์."
            )

        return api_key

    def validate_preprocess(self):
        """Validate preprocessing paths; create the output directory if needed.

        Raises:
            FileNotFoundError: if the metadata CSV or files folder is missing.
        """
        if not os.path.exists(self.META_CSV_PATH):
            raise FileNotFoundError(
                f"๋ฉํ CSV ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {self.META_CSV_PATH}"
            )

        if not os.path.exists(self.BASE_FOLDER_PATH):
            raise FileNotFoundError(
                f"ํ์ผ ํด๋๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: {self.BASE_FOLDER_PATH}"
            )

        output_dir = os.path.dirname(self.OUTPUT_CHUNKS_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        return True

    def validate_rag(self):
        """Validate the RAG settings (currently just the API key)."""
        if not self.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค")

        return True

    def validate_gguf(self):
        """Validate the GGUF settings; prints warnings instead of raising."""
        if not self.USE_MODEL_HUB:
            # Local file mode: warn (don't fail) if the model file is absent.
            if not os.path.exists(self.GGUF_MODEL_PATH):
                print(f"โ ๏ธ ๊ฒฝ๊ณ : GGUF ๋ชจ๋ธ ํ์ผ์ด ์์ต๋๋ค: {self.GGUF_MODEL_PATH}")
                print(f" USE_MODEL_HUB=true๋ก ์ค์ ํ์ฌ ์๋ ๋ค์ด๋ก๋ํ๊ฑฐ๋ ๋ชจ๋ธ ํ์ผ์ ์ค๋นํ์ธ์.")

        # Report whether GPU offloading is enabled.
        if self.GGUF_N_GPU_LAYERS > 0:
            print(f"โ GPU ๊ฐ์ ํ์ฑํ: {self.GGUF_N_GPU_LAYERS}๊ฐ ๋ ์ด์ด")
        else:
            print(f"โ ๏ธ CPU ์ ์ฉ ๋ชจ๋ (n_gpu_layers=0)")

        return True

    def validate_all(self):
        """Run all validators (preprocess, RAG, GGUF)."""
        self.validate_preprocess()
        self.validate_rag()
        self.validate_gguf()
        return True

    def validate(self):
        """Backward-compatible alias: validates only the preprocessing settings."""
        return self.validate_preprocess()

    def print_gguf_config(self):
        """Print the effective GGUF settings (debugging aid)."""
        print("\n" + "="*50)
        print("GGUF ๋ชจ๋ธ ์ค์ ")
        print("="*50)
        print(f"Model Hub ์ฌ์ฉ: {self.USE_MODEL_HUB}")

        if self.USE_MODEL_HUB:
            print(f"\n[QLoRA ๋ชจ๋ธ]")
            print(f" Hub Repo: {self.MODEL_HUB_REPO}")
            print(f" Hub ํ์ผ๋ช : {self.MODEL_HUB_FILENAME}")

            print(f"\n[Base ๋ชจ๋ธ]")
            print(f" Hub Repo: {self.BASE_MODEL_HUB_REPO}")
            print(f" Hub ํ์ผ๋ช : {self.BASE_MODEL_HUB_FILENAME}")

            print(f"\n[๊ณตํต]")
            print(f" ์บ์ ๋๋ ํ ๋ฆฌ: {self.MODEL_CACHE_DIR}")
        else:
            print(f"๋ก์ปฌ ๊ฒฝ๋ก: {self.GGUF_MODEL_PATH}")
        print(f"\nGPU ์ค์ :")
        print(f" - GPU ๋ ์ด์ด: {self.GGUF_N_GPU_LAYERS}")
        print(f" - ์ปจํ ์คํธ: {self.GGUF_N_CTX}")
        print(f" - ์ค๋ ๋: {self.GGUF_N_THREADS}")
        print(f"\n์์ฑ ์ค์ :")
        print(f" - Max Tokens: {self.GGUF_MAX_NEW_TOKENS}")
        print(f" - Temperature: {self.GGUF_TEMPERATURE}")
        print(f" - Top-P: {self.GGUF_TOP_P}")
        print("="*50 + "\n")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Backward-compatibility aliases: older modules import these names.
PreprocessConfig = Config
RAGConfig = Config
|