fix bugs
Browse files- .gitignore +4 -0
- Dockerfile +29 -10
- README.md +129 -33
- adapters/qwen-coder-pauq-lora/README.md +133 -196
- data/demo/sales.sqlite-journal +0 -0
- evaluate_pauq.py +271 -0
- scripts/run_app.py +113 -0
- scripts/smoke_local.py +272 -0
- src/api/main.py +157 -13
- src/api/schemas.py +80 -0
- src/data/prompt.py +70 -8
- src/data/schema.py +17 -66
- src/data/schema_provider.py +190 -0
- src/db/connector.py +43 -30
- src/db/executor.py +29 -28
- src/models/inference.py +40 -14
- src/models/postprocess.py +117 -20
- src/streamlit_app.py +10 -7
- streamlit_app.py +253 -257
- tests/test_db.py +132 -0
- tests/test_postprocess.py +142 -10
- tests/test_prompt.py +57 -2
- tests/test_schema_provider.py +182 -0
- tests/test_vocabulary.py +123 -0
.gitignore
CHANGED
|
@@ -68,3 +68,7 @@ htmlcov/
|
|
| 68 |
# Outputs
|
| 69 |
outputs/
|
| 70 |
results/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Outputs
|
| 69 |
outputs/
|
| 70 |
results/
|
| 71 |
+
|
| 72 |
+
# ΠΠΎΠΊΡΠΌΠ΅Π½ΡΡ ΠΠΠ
|
| 73 |
+
*.docx
|
| 74 |
+
*.pptx
|
Dockerfile
CHANGED
|
@@ -1,24 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
FROM python:3.12-slim
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
-
RUN apt-get update && apt-get install -y \
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
COPY requirements.txt ./
|
| 12 |
-
RUN pip3 install -r requirements.txt
|
| 13 |
|
| 14 |
COPY . .
|
| 15 |
|
| 16 |
-
|
| 17 |
-
ENV
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
EXPOSE 7860
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
ENTRYPOINT ["
|
|
|
|
| 1 |
+
# ΠΠ±ΡΠ°Π· Π΄Π»Ρ HuggingFace Spaces (Docker SDK).
|
| 2 |
+
# ΠΠ°ΠΏΡΡΠΊΠ°Π΅Ρ ΠΎΠ±Π° ΠΏΡΠΎΡΠ΅ΡΡΠ° Ru2SQL Π² ΠΎΠ΄Π½ΠΎΠΌ ΠΊΠΎΠ½ΡΠ΅ΠΉΠ½Π΅ΡΠ΅:
|
| 3 |
+
# - FastAPI Π½Π° 127.0.0.1:8000 (Π²Π½ΡΡΡΠ΅Π½Π½ΠΈΠΉ)
|
| 4 |
+
# - Streamlit Π½Π° 0.0.0.0:7860 (Π²Π½Π΅ΡΠ½ΠΈΠΉ, HF Spaces ΡΠΈΡΠ°Π΅Ρ Ρ ΡΡΠΎΠ³ΠΎ ΠΏΠΎΡΡΠ°)
|
| 5 |
+
# ΠΡΠΊΠ΅ΡΡΡΠΈΡΡΠ΅Ρ ΠΈΡ
scripts/run_app.py.
|
| 6 |
+
|
| 7 |
FROM python:3.12-slim
|
| 8 |
|
| 9 |
WORKDIR /app
|
| 10 |
|
| 11 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 12 |
+
build-essential \
|
| 13 |
+
curl \
|
| 14 |
+
git \
|
| 15 |
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
|
| 17 |
+
# Π‘Π½Π°ΡΠ°Π»Π° ΡΡΠ°Π²ΠΈΠΌ torch Ρ ΠΏΠΎΠ΄Π΄Π΅ΡΠΆΠΊΠΎΠΉ CUDA 12.1 β ΡΡΠΎ Π½ΡΠΆΠ½ΠΎ Π΄Π»Ρ T4 Small Π½Π° HF.
|
| 18 |
+
# ΠΠ° CPU-ΠΎΠΊΡΡΠΆΠ΅Π½ΠΈΠΈ ΡΡΠΎΡ ΠΆΠ΅ torch Π°Π²ΡΠΎΠΌΠ°ΡΠΈΡΠ΅ΡΠΊΠΈ ΡΠ°Π±ΠΎΡΠ°Π΅Ρ ΡΠ΅ΡΠ΅Π· CPU-Π±ΡΠΊΠ΅Π½Π΄.
|
| 19 |
+
RUN pip3 install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
|
| 20 |
+
|
| 21 |
COPY requirements.txt ./
|
| 22 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
| 23 |
|
| 24 |
COPY . .
|
| 25 |
|
| 26 |
+
# ΠΠ΅ΡΠΎΠ»ΡΡ Π΄Π»Ρ HF Spaces. ΠΠ΅ΡΠ΅ΠΏΠΈΡΡΠ²Π°ΡΡΡΡ Variables/Secrets Π² Π½Π°ΡΡΡΠΎΠΉΠΊΠ°Ρ
Space.
|
| 27 |
+
ENV LORA_ADAPTER_PATH=Tyycha/qwen-coder-pauq-lora \
|
| 28 |
+
BASE_MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct \
|
| 29 |
+
DEVICE=cuda \
|
| 30 |
+
API_HOST=127.0.0.1 \
|
| 31 |
+
API_PORT=8000 \
|
| 32 |
+
STREAMLIT_HOST=0.0.0.0 \
|
| 33 |
+
STREAMLIT_PORT=7860 \
|
| 34 |
+
RU2SQL_API_URL=http://127.0.0.1:8000
|
| 35 |
|
| 36 |
EXPOSE 7860
|
| 37 |
|
| 38 |
+
# start-period=600s β Ρ ΠΏΠ΅ΡΠ²ΠΎΠ³ΠΎ Π·Π°ΠΏΡΡΠΊΠ° Π΅ΡΡΡ Π΄ΠΎ 10 ΠΌΠΈΠ½ΡΡ Π½Π° ΡΠΊΠ°ΡΠΈΠ²Π°Π½ΠΈΠ΅ ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 39 |
+
# Ρ HuggingFace. ΠΠ°Π»ΡΡΠ΅ healthcheck ΠΎΠΏΡΠ°ΡΠΈΠ²Π°Π΅Ρ Streamlit ΠΊΠ°ΠΆΠ΄ΡΠ΅ 30 ΡΠ΅ΠΊΡΠ½Π΄.
|
| 40 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=600s --retries=3 \
|
| 41 |
+
CMD curl --fail http://localhost:7860/_stcore/health || exit 1
|
| 42 |
|
| 43 |
+
ENTRYPOINT ["python", "scripts/run_app.py"]
|
README.md
CHANGED
|
@@ -12,23 +12,43 @@ pinned: false
|
|
| 12 |
ΠΠ΅Π½Π΅ΡΠ°ΡΠΈΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ Π΄Π»Ρ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΡ Π²ΠΎΠΏΡΠΎΡΠΎΠ² Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ.
|
| 13 |
ΠΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠ°Ρ ΡΠ°ΡΡΡ ΠΠΠ , Π½Π°ΠΏΡΠ°Π²Π»Π΅Π½ΠΈΠ΅ Β«ΠΡΠΎΠ³ΡΠ°ΠΌΠΌΠ½Π°Ρ ΠΈΠ½ΠΆΠ΅Π½Π΅ΡΠΈΡΒ», 4 ΠΊΡΡΡ.
|
| 14 |
|
| 15 |
-
**Π‘ΡΠ΅ΠΊ:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, sqlglot.
|
| 16 |
**ΠΡΠ½ΠΎΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ:** Qwen2.5-Coder-3B-Instruct, Π΄ΠΎΠΎΠ±ΡΡΠ΅Π½Π½Π°Ρ ΠΌΠ΅ΡΠΎΠ΄ΠΎΠΌ QLoRA Π½Π° Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅ PAUQ.
|
| 17 |
**Π‘ΡΠ°Π²Π½Π΅Π½ΠΈΠ΅:** ruT5-base baseline + GigaChat API.
|
| 18 |
|
| 19 |
-
Π‘ΠΌ. `plan_VKR_text2sql_ru.md` Π΄Π»Ρ ΠΏΠΎΠ»Π½ΠΎΠ³ΠΎ ΠΏΠ»Π°Π½Π° ΡΠ°Π±ΠΎΡ
|
| 20 |
|
| 21 |
---
|
| 22 |
|
| 23 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
### 1. Π£ΡΡΠ°Π½ΠΎΠ²ΠΊΠ°
|
| 26 |
|
| 27 |
```bash
|
| 28 |
-
# Π£ΡΡΠ°Π½ΠΎΠ²ΠΈ uv (https://docs.astral.sh/uv/) Π΅ΡΠ»ΠΈ Π΅ΡΡ Π½Π΅Ρ
|
| 29 |
pip install uv
|
| 30 |
-
|
| 31 |
-
# ΠΠ»ΠΎΠ½ΠΈΡΡΠΉ ΡΠ΅ΠΏΠΎΠ·ΠΈΡΠΎΡΠΈΠΉ ΠΈ ΡΡΡΠ°Π½ΠΎΠ²ΠΈ Π·Π°Π²ΠΈΡΠΈΠΌΠΎΡΡΠΈ
|
| 32 |
git clone <ΡΠ²ΠΎΠΉ-ΡΠ΅ΠΏΠΎ> ru2sql
|
| 33 |
cd ru2sql
|
| 34 |
uv venv
|
|
@@ -44,36 +64,97 @@ copy .env.example .env # Windows
|
|
| 44 |
# cp .env.example .env # Linux/Mac
|
| 45 |
```
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
### 3.
|
| 50 |
|
| 51 |
```bash
|
| 52 |
-
|
| 53 |
-
# ΠΠ°ΡΠ΅ΠΌ ΡΠ°Π·Π»ΠΎΠΆΠΈ train.json/dev.json/test.json Π² data/pauq/
|
| 54 |
-
# ΠΈ SQLite-ΡΠ°ΠΉΠ»Ρ Π² data/databases/{db_id}/{db_id}.sqlite
|
| 55 |
```
|
| 56 |
|
| 57 |
-
|
| 58 |
|
|
|
|
|
|
|
|
|
|
| 59 |
```bash
|
| 60 |
-
|
| 61 |
```
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
### 5. ΠΠ°ΠΏΡΡΠΊ
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
```bash
|
| 69 |
uvicorn src.api.main:app --reload
|
| 70 |
# Swagger UI: http://127.0.0.1:8000/docs
|
| 71 |
```
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
###
|
| 77 |
|
| 78 |
```bash
|
| 79 |
curl -X POST http://127.0.0.1:8000/generate-sql \
|
|
@@ -90,7 +171,7 @@ curl -X POST http://127.0.0.1:8000/generate-sql \
|
|
| 90 |
|
| 91 |
Π¨Π°Π³ΠΈ:
|
| 92 |
1. ΠΡΠΊΡΠΎΠΉ `notebooks/kaggle_train_qwen_qlora.ipynb` Π½Π° kaggle.com.
|
| 93 |
-
2. Π Settings Π²ΡΠ±Π΅ΡΠΈ Accelerator: GPU T4 x1
|
| 94 |
3. Add-ons β Secrets β Π΄ΠΎΠ±Π°Π²Ρ `HF_TOKEN` ΠΈ `WANDB_API_KEY`.
|
| 95 |
4. ΠΠ°ΠΏΡΡΡΠΈ Π²ΡΠ΅ ΡΡΠ΅ΠΉΠΊΠΈ. Π’ΡΠ΅Π½ΠΈΡΠΎΠ²ΠΊΠ° ~4β6 ΡΠ°ΡΠΎΠ².
|
| 96 |
5. ΠΠΎ Π·Π°Π²Π΅ΡΡΠ΅Π½ΠΈΠΈ Π°Π΄Π°ΠΏΡΠ΅Ρ ΠΏΡΡΠΈΡΡΡ Π½Π° ΡΠ²ΠΎΠΉ ΠΏΡΠΈΠ²Π°ΡΠ½ΡΠΉ HF-ΡΠ΅ΠΏΠΎ.
|
|
@@ -109,32 +190,47 @@ curl -X POST http://127.0.0.1:8000/generate-sql \
|
|
| 109 |
|
| 110 |
```
|
| 111 |
ru2sql/
|
| 112 |
-
βββ pyproject.toml # Π·Π°Π²ΠΈΡΠΈΠΌΠΎΡΡΠΈ
|
| 113 |
βββ .env.example # ΡΠ°Π±Π»ΠΎΠ½ ΠΊΠΎΠ½ΡΠΈΠ³ΡΡΠ°ΡΠΈΠΈ
|
| 114 |
-
βββ plan_VKR_text2sql_ru.md # ΠΏΠ»Π°Π½ ΡΠ°Π±ΠΎΡ
|
| 115 |
βββ notebooks/
|
| 116 |
β βββ kaggle_train_qwen_qlora.ipynb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
βββ src/
|
| 118 |
β βββ config.py # Π½Π°ΡΡΡΠΎΠΉΠΊΠΈ ΡΠ΅ΡΠ΅Π· pydantic-settings
|
| 119 |
β βββ data/
|
| 120 |
β β βββ loader.py # ΡΡΠ΅Π½ΠΈΠ΅ PAUQ JSON
|
| 121 |
-
β β βββ
|
|
|
|
| 122 |
β β βββ prompt.py # PromptBuilder + chat-template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
β βββ models/
|
| 124 |
β β βββ inference.py # InferenceEngine (ΠΌΠΎΠ΄Π΅Π»Ρ + LoRA)
|
| 125 |
-
β β βββ postprocess.py # ΠΎΡΠΈΡΡΠΊΠ° SQL +
|
| 126 |
β βββ evaluation/
|
| 127 |
-
β β βββ metrics.py #
|
| 128 |
-
β β βββ evaluate.py # CLI Π΄Π»Ρ ΠΏΡΠΎΠ³ΠΎΠ½Π° Π½Π° split
|
| 129 |
β βββ api/
|
| 130 |
-
β βββ main.py # FastAPI app
|
| 131 |
β βββ schemas.py # Pydantic-ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 132 |
β βββ dependencies.py # lifespan + DI
|
| 133 |
-
|
|
|
|
| 134 |
βββ test_prompt.py
|
| 135 |
βββ test_postprocess.py
|
| 136 |
βββ test_metrics.py
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
| 138 |
```
|
| 139 |
|
| 140 |
---
|
|
@@ -153,13 +249,13 @@ python -m src.evaluation.evaluate --split dev --limit 50
|
|
| 153 |
|
| 154 |
---
|
| 155 |
|
| 156 |
-
## ΠΠ΅ΡΡΠΈΠΊΠΈ
|
| 157 |
|
| 158 |
| ΠΠΎΠ΄Π΅Π»Ρ | EM | Execution Accuracy |
|
| 159 |
|---|---|---|
|
| 160 |
-
| ruT5-base (baseline) | 25β35% | 30β40% |
|
| 161 |
-
| **Qwen2.5-Coder-3B + QLoRA** | **
|
| 162 |
-
|
|
| 163 |
|
| 164 |
---
|
| 165 |
|
|
|
|
| 12 |
ΠΠ΅Π½Π΅ΡΠ°ΡΠΈΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ Π΄Π»Ρ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΡ Π²ΠΎΠΏΡΠΎΡΠΎΠ² Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ.
|
| 13 |
ΠΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠ°Ρ ΡΠ°ΡΡΡ ΠΠΠ , Π½Π°ΠΏΡΠ°Π²Π»Π΅Π½ΠΈΠ΅ Β«ΠΡΠΎΠ³ΡΠ°ΠΌΠΌΠ½Π°Ρ ΠΈΠ½ΠΆΠ΅Π½Π΅ΡΠΈΡΒ», 4 ΠΊΡΡΡ.
|
| 14 |
|
| 15 |
+
**Π‘ΡΠ΅ΠΊ:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, Streamlit, sqlglot.
|
| 16 |
**ΠΡΠ½ΠΎΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ:** Qwen2.5-Coder-3B-Instruct, Π΄ΠΎΠΎΠ±ΡΡΠ΅Π½Π½Π°Ρ ΠΌΠ΅ΡΠΎΠ΄ΠΎΠΌ QLoRA Π½Π° Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅ PAUQ.
|
| 17 |
**Π‘ΡΠ°Π²Π½Π΅Π½ΠΈΠ΅:** ruT5-base baseline + GigaChat API.
|
| 18 |
|
| 19 |
+
Π‘ΠΌ. `plan_VKR_text2sql_ru.md` Π΄Π»Ρ ΠΏΠΎΠ»Π½ΠΎΠ³ΠΎ ΠΏΠ»Π°Π½Π° ΡΠ°Π±ΠΎΡ.
|
| 20 |
|
| 21 |
---
|
| 22 |
|
| 23 |
+
## ΠΡΡ
ΠΈΡΠ΅ΠΊΡΡΡΠ°
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
βββββββββββββββββββββββ HTTP ββββββββββββββββββββββββ
|
| 27 |
+
β Streamlit-ΠΊΠ»ΠΈΠ΅Π½Ρ β ββββββββββΊ β FastAPI REST API β
|
| 28 |
+
β (ΠΏΠΎΡΡ 8501) β β (ΠΏΠΎΡΡ 8000) β
|
| 29 |
+
βββββββββββββββββββββββ ββββββββββββ¬ββββββββββββ
|
| 30 |
+
β
|
| 31 |
+
ββββββββββββββββββββΌβββββββββββββββββββ
|
| 32 |
+
βΌ βΌ βΌ
|
| 33 |
+
ββββββββββββββββββββ ββββββββββββββββ ββββββββββββββββββ
|
| 34 |
+
β InferenceEngine β β SchemaProvi- β β BusinessVocab- β
|
| 35 |
+
β Qwen + LoRA β β der + Sql- β β ulary (YAML) β
|
| 36 |
+
β β β Executor β β β
|
| 37 |
+
ββββββββββββββββββββ ββββββββββββββββ ββββββββββββββββββ
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Streamlit-ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ Π½Π΅ Π²ΡΠ·ΡΠ²Π°Π΅Ρ ΠΌΠΎΠ΄Π΅Π»Ρ Π½Π°ΠΏΡΡΠΌΡΡ β ΠΎΠ½ ΠΎΠ±ΡΠ°ΡΠ°Π΅ΡΡΡ ΠΊ REST API
|
| 41 |
+
ΡΠ΅ΡΠ΅Π· `httpx`. ΠΡΠΎ ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ Π·Π°ΠΏΡΡΠΊΠ°ΡΡ UI ΠΈ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ Π½Π° ΡΠ°Π·Π½ΡΡ
ΠΌΠ°ΡΠΈΠ½Π°Ρ
,
|
| 42 |
+
Π° ΡΠ°ΠΊΠΆΠ΅ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ°ΡΡ ΠΊ API Π»ΡΠ±ΡΡ
ΡΡΠΎΡΠΎΠ½Π½ΠΈΡ
ΠΊΠ»ΠΈΠ΅Π½ΡΠΎΠ².
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## ΠΡΡΡΡΡΠΉ ΡΡΠ°ΡΡ
|
| 47 |
|
| 48 |
### 1. Π£ΡΡΠ°Π½ΠΎΠ²ΠΊΠ°
|
| 49 |
|
| 50 |
```bash
|
|
|
|
| 51 |
pip install uv
|
|
|
|
|
|
|
| 52 |
git clone <ΡΠ²ΠΎΠΉ-ΡΠ΅ΠΏΠΎ> ru2sql
|
| 53 |
cd ru2sql
|
| 54 |
uv venv
|
|
|
|
| 64 |
# cp .env.example .env # Linux/Mac
|
| 65 |
```
|
| 66 |
|
| 67 |
+
ΠΠ°ΠΏΠΎΠ»Π½ΠΈ Π² `.env`:
|
| 68 |
+
- `BASE_MODEL_NAME` (ΠΏΠΎ ΡΠΌΠΎΠ»ΡΠ°Π½ΠΈΡ Qwen/Qwen2.5-Coder-3B-Instruct),
|
| 69 |
+
- `LORA_ADAPTER_PATH` (Π»ΠΎΠΊΠ°Π»ΡΠ½Π°Ρ ΠΏΠ°ΠΏΠΊΠ° ΠΈΠ»ΠΈ HF-repo),
|
| 70 |
+
- ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎ β API-ΠΊΠ»ΡΡ GigaChat Π΄Π»Ρ baseline-ΡΡΠ°Π²Π½Π΅Π½ΠΈΡ.
|
| 71 |
+
|
| 72 |
+
ΠΡΠ»ΠΈ HuggingFace Π½Π΅Π΄ΠΎΡΡΡΠΏΠ΅Π½ ΠΈΠ· Π²Π°ΡΠ΅ΠΉ ΡΠ΅ΡΠΈ, Π΄ΠΎΠ±Π°Π²Ρ:
|
| 73 |
+
```
|
| 74 |
+
HF_ENDPOINT=https://hf-mirror.com
|
| 75 |
+
```
|
| 76 |
|
| 77 |
+
### 3. Π’Π΅ΡΡΡ
|
| 78 |
|
| 79 |
```bash
|
| 80 |
+
pytest -v
|
|
|
|
|
|
|
| 81 |
```
|
| 82 |
|
| 83 |
+
ΠΠΆΠΈΠ΄Π°Π΅ΠΌΠΎ: 80+ Π·Π΅Π»ΡΠ½ΡΡ
ΡΠ΅ΡΡΠΎΠ².
|
| 84 |
|
| 85 |
+
### 4. Smoke-ΠΏΡΠΎΠ²Π΅ΡΠΊΠ°
|
| 86 |
+
|
| 87 |
+
ΠΡΡΡΡΠ°Ρ (5 ΡΠ΅ΠΊ, Π±Π΅Π· ΠΌΠΎΠ΄Π΅Π»ΠΈ):
|
| 88 |
```bash
|
| 89 |
+
python scripts/smoke_local.py
|
| 90 |
```
|
| 91 |
|
| 92 |
+
ΠΠΎΠ»Π½Π°Ρ (Ρ Π·Π°Π³ΡΡΠ·ΠΊΠΎΠΉ Qwen, ~5 ΠΌΠΈΠ½ΡΡ Π½Π° CPU):
|
| 93 |
+
```bash
|
| 94 |
+
python scripts/smoke_local.py --with-model
|
| 95 |
+
```
|
| 96 |
|
| 97 |
+
### 5. ΠΠ°ΠΏΡΡΠΊ ΠΏΡΠΈΠ»ΠΎΠΆΠ΅Π½ΠΈΡ
|
| 98 |
|
| 99 |
+
ΠΡΠΆΠ½Ρ **Π΄Π²Π° ΠΏΡΠΎΡΠ΅ΡΡΠ°** β API ΠΈ UI:
|
| 100 |
+
|
| 101 |
+
**ΠΠΊΠ½ΠΎ 1** β REST API:
|
| 102 |
```bash
|
| 103 |
uvicorn src.api.main:app --reload
|
| 104 |
# Swagger UI: http://127.0.0.1:8000/docs
|
| 105 |
```
|
| 106 |
|
| 107 |
+
**ΠΠΊΠ½ΠΎ 2** β Streamlit-ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ:
|
| 108 |
+
```bash
|
| 109 |
+
streamlit run streamlit_app.py
|
| 110 |
+
# UI: http://127.0.0.1:8501
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
ΠΡΠΈ ΠΏΠ΅ΡΠ²ΠΎΠΌ Π·Π°ΠΏΡΡΠΊΠ΅ ΠΌΠΎΠ΄Π΅Π»Ρ Qwen2.5-Coder-3B (~6 GB) ΡΠΊΠ°ΡΠΈΠ²Π°Π΅ΡΡΡ ΠΈΠ· HuggingFace
|
| 114 |
+
Hub. ΠΠ° CPU ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ ΠΎΠ΄Π½ΠΎΠ³ΠΎ Π·Π°ΠΏΡΠΎΡΠ° Π·Π°Π½ΠΈΠΌΠ°Π΅Ρ 15β30 ΡΠ΅ΠΊΡΠ½Π΄ β ΡΡΠΎ ΠΎΠΆΠΈΠ΄Π°Π΅ΠΌΠΎ.
|
| 115 |
+
|
| 116 |
+
ΠΠ΄ΡΠ΅Ρ API ΠΌΠΎΠΆΠ½ΠΎ ΠΏΠ΅ΡΠ΅ΠΎΠΏΡΠ΅Π΄Π΅Π»ΠΈΡΡ ΠΏΠ΅ΡΠ΅ΠΌΠ΅Π½Π½ΠΎΠΉ ΠΎΠΊΡΡΠΆΠ΅Π½ΠΈΡ:
|
| 117 |
+
```bash
|
| 118 |
+
set RU2SQL_API_URL=http://192.168.1.10:8000 # Windows
|
| 119 |
+
# export RU2SQL_API_URL=http://192.168.1.10:8000 # Linux/Mac
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## REST API
|
| 125 |
+
|
| 126 |
+
### ΠΠ°Π·ΠΎΠ²ΡΠ΅ ΡΠ½Π΄ΠΏΠΎΠΈΠ½ΡΡ
|
| 127 |
+
|
| 128 |
+
| ΠΠ΅ΡΠΎΠ΄ | ΠΡΡΡ | ΠΠ°Π·Π½Π°ΡΠ΅Π½ΠΈΠ΅ |
|
| 129 |
+
|---|---|---|
|
| 130 |
+
| GET | `/health` | ΡΡΠ°ΡΡΡ ΡΠ΅ΡΠ²ΠΈΡΠ° ΠΈ Π·Π°Π³ΡΡΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ |
|
| 131 |
+
| GET | `/databases` | ΡΠΏΠΈΡΠΎΠΊ ΠΠ ΠΈΠ· data/databases (PAUQ-ΡΡΡΡΠΊΡΡΡΠ°) |
|
| 132 |
+
| POST | `/generate-sql` | Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ SQL ΠΏΠΎ db_id ΠΈΠ· PAUQ |
|
| 133 |
+
| POST | `/schema` | ΡΡ
Π΅ΠΌΠ° ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ ΠΏΠΎ connection string |
|
| 134 |
+
| POST | `/query` | ΠΏΠΎΠ»Π½ΡΠΉ pipeline Π΄Π»Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ |
|
| 135 |
+
|
| 136 |
+
### ΠΡΠΈΠΌΠ΅Ρ: Π·Π°ΠΏΡΠΎΡ ΠΊ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
curl -X POST http://127.0.0.1:8000/query \
|
| 140 |
+
-H "Content-Type: application/json" \
|
| 141 |
+
-d '{
|
| 142 |
+
"question": "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?",
|
| 143 |
+
"connection_string": "sqlite:///data/demo/sales.sqlite",
|
| 144 |
+
"execute": true,
|
| 145 |
+
"vocabulary": {
|
| 146 |
+
"company": "ΠΠ΅ΠΌΠΎ-ΠΌΠ°Π³Π°Π·ΠΈΠ½",
|
| 147 |
+
"terms": {"Π²ΡΡΡΡΠΊΠ°": "SUM(orders.amount) WHERE status='paid'"}
|
| 148 |
+
}
|
| 149 |
+
}'
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
ΠΡΠ²Π΅Ρ ΡΠΎΠ΄Π΅ΡΠΆΠΈΡ `sql`, `raw_output`, `is_valid_sql`, `gen_time_seconds` ΠΈ
|
| 153 |
+
ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎ `execution` Ρ ΡΠ΅Π·ΡΠ»ΡΡΠ°ΡΠ°ΠΌΠΈ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ. ΠΠ΅ΡΠ΅Π΄ ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ
|
| 154 |
+
SQL ΠΏΡΠΎΡ
ΠΎΠ΄ΠΈΡ AST-ΡΡΠΎΠ²Π½Π΅Π²ΡΡ ΠΏΡΠΎΠ²Π΅ΡΠΊΡ (ΡΠΌ. `is_select_only` Π²
|
| 155 |
+
`src/models/postprocess.py`) β DDL ΠΈ DML Π½Π° ΠΠ ΡΠ΅ΡΠ΅Π· API Π½Π΅Π²ΠΎΠ·ΠΌΠΎΠΆΠ½Ρ.
|
| 156 |
|
| 157 |
+
### ΠΡΠΈΠΌΠ΅Ρ: PAUQ-ΡΠ΅ΠΆΠΈΠΌ (ΡΡΠ°ΡΡΠΉ ΡΠ½Π΄ΠΏΠΎΠΈΠ½Ρ)
|
| 158 |
|
| 159 |
```bash
|
| 160 |
curl -X POST http://127.0.0.1:8000/generate-sql \
|
|
|
|
| 171 |
|
| 172 |
Π¨Π°Π³ΠΈ:
|
| 173 |
1. ΠΡΠΊΡΠΎΠΉ `notebooks/kaggle_train_qwen_qlora.ipynb` Π½Π° kaggle.com.
|
| 174 |
+
2. Π Settings Π²ΡΠ±Π΅ΡΠΈ Accelerator: GPU T4 x1.
|
| 175 |
3. Add-ons β Secrets β Π΄ΠΎΠ±Π°Π²Ρ `HF_TOKEN` ΠΈ `WANDB_API_KEY`.
|
| 176 |
4. ΠΠ°ΠΏΡΡΡΠΈ Π²ΡΠ΅ ΡΡΠ΅ΠΉΠΊΠΈ. Π’ΡΠ΅Π½ΠΈΡΠΎΠ²ΠΊΠ° ~4β6 ΡΠ°ΡΠΎΠ².
|
| 177 |
5. ΠΠΎ Π·Π°Π²Π΅ΡΡΠ΅Π½ΠΈΠΈ Π°Π΄Π°ΠΏΡΠ΅Ρ ΠΏΡΡΠΈΡΡΡ Π½Π° ΡΠ²ΠΎΠΉ ΠΏΡΠΈΠ²Π°ΡΠ½ΡΠΉ HF-ΡΠ΅ΠΏΠΎ.
|
|
|
|
| 190 |
|
| 191 |
```
|
| 192 |
ru2sql/
|
| 193 |
+
βββ pyproject.toml # Π·Π°Π²ΠΈΡΠΈΠΌΠΎΡΡΠΈ
|
| 194 |
βββ .env.example # ΡΠ°Π±Π»ΠΎΠ½ ΠΊΠΎΠ½ΡΠΈΠ³ΡΡΠ°ΡΠΈΠΈ
|
| 195 |
+
βββ plan_VKR_text2sql_ru.md # ΠΏΠ»Π°Π½ ΡΠ°Π±ΠΎΡ
|
| 196 |
βββ notebooks/
|
| 197 |
β βββ kaggle_train_qwen_qlora.ipynb
|
| 198 |
+
βββ scripts/
|
| 199 |
+
β βββ smoke_local.py # Π»ΠΎΠΊΠ°Π»ΡΠ½Π°Ρ ΠΏΡΠΎΠ²Π΅ΡΠΊΠ° ΡΠ°Π±ΠΎΡΠΎΡΠΏΠΎΡΠΎΠ±Π½ΠΎΡΡΠΈ
|
| 200 |
+
βββ configs/
|
| 201 |
+
β βββ example_vocabulary.yaml
|
| 202 |
+
β βββ sales_vocabulary.yaml
|
| 203 |
βββ src/
|
| 204 |
β βββ config.py # Π½Π°ΡΡΡΠΎΠΉΠΊΠΈ ΡΠ΅ΡΠ΅Π· pydantic-settings
|
| 205 |
β βββ data/
|
| 206 |
β β βββ loader.py # ΡΡΠ΅Π½ΠΈΠ΅ PAUQ JSON
|
| 207 |
+
β β βββ schema_provider.py # SchemaProvider β Π΅Π΄ΠΈΠ½ΡΠΉ ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ
|
| 208 |
+
β β βββ schema.py # SchemaRetriever (ΡΠ°ΡΠ°Π΄ Π΄Π»Ρ PAUQ)
|
| 209 |
β β βββ prompt.py # PromptBuilder + chat-template
|
| 210 |
+
β βββ db/
|
| 211 |
+
β β βββ connector.py # DbConnector β ΡΡΠ΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌ
|
| 212 |
+
β β βββ executor.py # SqlExecutor Ρ read-only
|
| 213 |
+
β βββ business/
|
| 214 |
+
β β βββ vocabulary.py # BusinessVocabulary (YAML-ΠΊΠΎΠ½ΡΠΈΠ³)
|
| 215 |
β βββ models/
|
| 216 |
β β βββ inference.py # InferenceEngine (ΠΌΠΎΠ΄Π΅Π»Ρ + LoRA)
|
| 217 |
+
β β βββ postprocess.py # ΠΎΡΠΈΡΡΠΊΠ° SQL + guardrail
|
| 218 |
β βββ evaluation/
|
| 219 |
+
β β βββ metrics.py # EM + Execution Accuracy
|
| 220 |
+
β β βββ evaluate.py # CLI Π΄Π»Ρ ΠΏΡΠΎΠ³ΠΎΠ½Π° Π½Π° split
|
| 221 |
β βββ api/
|
| 222 |
+
β βββ main.py # FastAPI app (5 ΡΠ½Π΄ΠΏΠΎΠΈΠ½ΡΠΎΠ²)
|
| 223 |
β βββ schemas.py # Pydantic-ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 224 |
β βββ dependencies.py # lifespan + DI
|
| 225 |
+
ββοΏ½οΏ½ streamlit_app.py # UI (HTTPX-ΠΊΠ»ΠΈΠ΅Π½Ρ ΠΊ API)
|
| 226 |
+
βββ tests/ # 80+ ΡΠ΅ΡΡΠΎΠ²
|
| 227 |
βββ test_prompt.py
|
| 228 |
βββ test_postprocess.py
|
| 229 |
βββ test_metrics.py
|
| 230 |
+
βββ test_schema.py
|
| 231 |
+
βββ test_schema_provider.py
|
| 232 |
+
βββ test_vocabulary.py
|
| 233 |
+
βββ test_db.py
|
| 234 |
```
|
| 235 |
|
| 236 |
---
|
|
|
|
| 249 |
|
| 250 |
---
|
| 251 |
|
| 252 |
+
## ΠΠ΅ΡΡΠΈΠΊΠΈ
|
| 253 |
|
| 254 |
| ΠΠΎΠ΄Π΅Π»Ρ | EM | Execution Accuracy |
|
| 255 |
|---|---|---|
|
| 256 |
+
| ruT5-base (baseline) | 25β35 % | 30β40 % |
|
| 257 |
+
| **Qwen2.5-Coder-3B + QLoRA** | **40,0 %** | **71,9 %** |
|
| 258 |
+
| BRIDGE / RAT-SQL (PAUQ, mono) | 51 / 52 % | 48 / 49 % |
|
| 259 |
|
| 260 |
---
|
| 261 |
|
adapters/qwen-coder-pauq-lora/README.md
CHANGED
|
@@ -1,199 +1,136 @@
|
|
| 1 |
---
|
| 2 |
-
library_name:
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
---
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
###
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
##
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
### Results
|
| 128 |
-
|
| 129 |
-
[More Information Needed]
|
| 130 |
-
|
| 131 |
-
#### Summary
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
## Model Examination [optional]
|
| 136 |
-
|
| 137 |
-
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
-
|
| 139 |
-
[More Information Needed]
|
| 140 |
-
|
| 141 |
-
## Environmental Impact
|
| 142 |
-
|
| 143 |
-
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
-
|
| 145 |
-
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
-
|
| 147 |
-
- **Hardware Type:** [More Information Needed]
|
| 148 |
-
- **Hours used:** [More Information Needed]
|
| 149 |
-
- **Cloud Provider:** [More Information Needed]
|
| 150 |
-
- **Compute Region:** [More Information Needed]
|
| 151 |
-
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
-
|
| 153 |
-
## Technical Specifications [optional]
|
| 154 |
-
|
| 155 |
-
### Model Architecture and Objective
|
| 156 |
-
|
| 157 |
-
[More Information Needed]
|
| 158 |
-
|
| 159 |
-
### Compute Infrastructure
|
| 160 |
-
|
| 161 |
-
[More Information Needed]
|
| 162 |
-
|
| 163 |
-
#### Hardware
|
| 164 |
-
|
| 165 |
-
[More Information Needed]
|
| 166 |
-
|
| 167 |
-
#### Software
|
| 168 |
-
|
| 169 |
-
[More Information Needed]
|
| 170 |
-
|
| 171 |
-
## Citation [optional]
|
| 172 |
-
|
| 173 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
-
|
| 175 |
-
**BibTeX:**
|
| 176 |
-
|
| 177 |
-
[More Information Needed]
|
| 178 |
-
|
| 179 |
-
**APA:**
|
| 180 |
-
|
| 181 |
-
[More Information Needed]
|
| 182 |
-
|
| 183 |
-
## Glossary [optional]
|
| 184 |
-
|
| 185 |
-
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
-
|
| 187 |
-
[More Information Needed]
|
| 188 |
-
|
| 189 |
-
## More Information [optional]
|
| 190 |
-
|
| 191 |
-
[More Information Needed]
|
| 192 |
-
|
| 193 |
-
## Model Card Authors [optional]
|
| 194 |
-
|
| 195 |
-
[More Information Needed]
|
| 196 |
-
|
| 197 |
-
## Model Card Contact
|
| 198 |
-
|
| 199 |
-
[More Information Needed]
|
|
|
|
| 1 |
---
|
| 2 |
+
library_name: peft
|
| 3 |
+
base_model: Qwen/Qwen2.5-Coder-3B-Instruct
|
| 4 |
+
language:
|
| 5 |
+
- ru
|
| 6 |
+
tags:
|
| 7 |
+
- text-to-sql
|
| 8 |
+
- qlora
|
| 9 |
+
- russian
|
| 10 |
+
- sql-generation
|
| 11 |
+
license: apache-2.0
|
| 12 |
+
datasets:
|
| 13 |
+
- ai-forever/PAUQ
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# qwen-coder-pauq-lora
|
| 17 |
+
|
| 18 |
+
LoRA-Π°Π΄Π°ΠΏΡΠ΅Ρ Π΄Π»Ρ ΠΌΠΎΠ΄Π΅Π»ΠΈ **Qwen2.5-Coder-3B-Instruct**, ΠΎΠ±ΡΡΠ΅Π½Π½ΡΠΉ ΠΌΠ΅ΡΠΎΠ΄ΠΎΠΌ
|
| 19 |
+
**QLoRA** Π½Π° Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅ **PAUQ** Π΄Π»Ρ Π·Π°Π΄Π°ΡΠΈ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΡ Π²ΠΎΠΏΡΠΎΡΠΎΠ² Π½Π°
|
| 20 |
+
ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ. Π§Π°ΡΡΡ Π²ΡΠΏΡΡΠΊΠ½ΠΎΠΉ ΠΊΠ²Π°Π»ΠΈΡΠΈΠΊΠ°ΡΠΈΠΎΠ½Π½ΠΎΠΉ ΡΠ°Π±ΠΎΡΡ
|
| 21 |
+
ΠΏΠΎ Π½Π°ΠΏΡΠ°Π²Π»Π΅Π½ΠΈΡ 09.03.04 Β«ΠΡΠΎΠ³ΡΠ°ΠΌΠΌΠ½Π°Ρ ΠΈΠ½ΠΆΠ΅Π½Π΅ΡΠΈΡΒ», ΠΠΠΠ’Π£-ΠΠΠ.
|
| 22 |
+
|
| 23 |
+
## ΠΠΏΠΈΡΠ°Π½ΠΈΠ΅
|
| 24 |
+
|
| 25 |
+
ΠΠ΄Π°ΠΏΡΠ΅Ρ ΠΎΠ±ΡΡΠ°Π΅Ρ ΠΌΠΎΠ΄Π΅Π»Ρ ΠΎΡΠ²Π΅ΡΠ°ΡΡ Π½Π° ΡΡΡΡΠΊΠΎΡΠ·ΡΡΠ½ΡΠΉ Π°Π½Π°Π»ΠΈΡΠΈΡΠ΅ΡΠΊΠΈΠΉ Π²ΠΎΠΏΡΠΎΡ
|
| 26 |
+
ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠΈ ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠΌ SQL-Π·Π°ΠΏΡΠΎΡΠΎΠΌ, ΡΡΠΈΡΡΠ²Π°Ρ ΡΡ
Π΅ΠΌΡ ΠΊΠΎΠ½ΠΊΡΠ΅ΡΠ½ΠΎΠΉ Π±Π°Π·Ρ
|
| 27 |
+
Π΄Π°Π½Π½ΡΡ
, ΠΏΠ΅ΡΠ΅Π΄Π°Π½Π½ΡΡ Π² ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠΌ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠΈ Π²ΠΌΠ΅ΡΡΠ΅ Ρ ΠΏΡΠΈΠΌΠ΅ΡΠ°ΠΌΠΈ ΡΡΡΠΎΠΊ.
|
| 28 |
+
|
| 29 |
+
ΠΠ°Π·ΠΎΠ²Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ ΠΎΡΡΠ°ΡΡΡΡ Π·Π°ΠΌΠΎΡΠΎΠΆΠ΅Π½Π½ΠΎΠΉ, ΠΎΠ±ΡΡΠ°ΡΡΡΡ ΡΠΎΠ»ΡΠΊΠΎ LoRA-ΠΌΠ°ΡΡΠΈΡΡ
|
| 30 |
+
ΡΠ°Π½Π³ΠΎΠΌ 16, Π½Π°Π»ΠΎΠΆΠ΅Π½Π½ΡΠ΅ Π½Π° Π²ΡΠ΅ ΠΏΡΠΎΠ΅ΠΊΡΠΈΠΎΠ½Π½ΡΠ΅ ΡΠ»ΠΎΠΈ attention ΠΈ MLP. ΠΡΠΎ
|
| 31 |
+
ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ Ρ
ΡΠ°Π½ΠΈΡΡ ΠΈ ΡΠ°ΡΠΏΡΠΎΡΡΡΠ°Π½ΡΡΡ Π°Π΄Π°ΠΏΡΠ΅Ρ ΡΠ°Π·ΠΌΠ΅ΡΠΎΠΌ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ Π΄Π΅ΡΡΡΠΊΠΎΠ²
|
| 32 |
+
ΠΌΠ΅Π³Π°Π±Π°ΠΉΡ, Π° Π½Π΅ ΠΏΠΎΠ»Π½ΡΠ΅ Π²Π΅ΡΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ Π² Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ Π³ΠΈΠ³Π°Π±Π°ΠΉΡ.
|
| 33 |
+
|
| 34 |
+
## ΠΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΠ΅
|
| 35 |
+
|
| 36 |
+
```python
|
| 37 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 38 |
+
from peft import PeftModel
|
| 39 |
+
|
| 40 |
+
base = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 41 |
+
adapter = "Tyycha/qwen-coder-pauq-lora"
|
| 42 |
+
|
| 43 |
+
tokenizer = AutoTokenizer.from_pretrained(base)
|
| 44 |
+
model = AutoModelForCausalLM.from_pretrained(base, device_map="auto")
|
| 45 |
+
model = PeftModel.from_pretrained(model, adapter)
|
| 46 |
+
model.eval()
|
| 47 |
+
|
| 48 |
+
messages = [
|
| 49 |
+
{"role": "system", "content": "Π’Ρ β Π°ΡΡΠΈΡΡΠ΅Π½Ρ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΡΠ΅Ρ Π²ΠΎΠΏΡΠΎΡΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² SQL..."},
|
| 50 |
+
{"role": "user", "content": "### Schema:\nCREATE TABLE orders (id INT, amount REAL, status TEXT);\n\n### Question:\nΠΠ°ΠΊΠ°Ρ ΡΡΠΌΠΌΠ°ΡΠ½Π°Ρ Π²ΡΡΡΡΠΊΠ° ΠΏΠΎ ΠΎΠΏΠ»Π°ΡΠ΅Π½Π½ΡΠΌ Π·Π°ΠΊΠ°Π·Π°ΠΌ?\n\n### SQL:\n"},
|
| 51 |
+
]
|
| 52 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 53 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 54 |
+
output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
|
| 55 |
+
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
ΠΠΎΠ»Π½ΡΠΉ pipeline (ΡΠΎΡΠΌΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ ΠΏΡΠΎΠΌΠΏΡΠ°, ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ°, ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ SQL)
|
| 59 |
+
ΡΠ΅Π°Π»ΠΈΠ·ΠΎΠ²Π°Π½ Π² ΠΏΡΠΎΠ΅ΠΊΡΠ΅ [Ru2SQL](https://github.com/Tyycha/Ru2SQL).
|
| 60 |
+
|
| 61 |
+
## ΠΠ°Π½Π½ΡΠ΅
|
| 62 |
+
|
| 63 |
+
- **ΠΠ°ΡΠ°ΡΠ΅Ρ:** PAUQ (Bakshandaeva et al., 2022) β ΠΏΠ΅ΡΠ²ΡΠΉ ΠΊΡΡΠΏΠ½ΡΠΉ ΡΡΡΡΠΊΠΎΡΠ·ΡΡΠ½ΡΠΉ
|
| 64 |
+
ΠΊΠΎΡΠΏΡΡ Π΄Π»Ρ Π·Π°Π΄Π°ΡΠΈ Text-to-SQL, ΠΏΠΎΡΡΡΠΎΠ΅Π½Π½ΡΠΉ Π½Π° ΠΎΡΠ½ΠΎΠ²Π΅ Spider.
|
| 65 |
+
- **Π Π°Π·ΠΌΠ΅Ρ:** ~7 ΡΡΡ. ΠΏΠ°Ρ (Π²ΠΎΠΏΡΠΎΡ, SQL) Π² train, 1076 Π² dev.
|
| 66 |
+
- **Π‘ΡΡΡΠΊΡΡΡΠ° ΠΏΡΠΈΠΌΠ΅ΡΠ°:** Π΄ΠΈΠ°Π»ΠΎΠ³ Π² ΡΠΎΡΠΌΠ°ΡΠ΅ chat-template ΠΈΠ· ΡΡΡΡ
ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠΉ
|
| 67 |
+
(`system` Ρ ΠΈΠ½ΡΡΡΡΠΊΡΠΈΠ΅ΠΉ ΠΈ ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΡΠΌ Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡΠΌ, `user` ΡΠΎ
|
| 68 |
+
ΡΡ
Π΅ΠΌΠΎΠΉ ΠΈ Π²ΠΎΠΏΡΠΎΡΠΎΠΌ, `assistant` Ρ ΡΡΠ°Π»ΠΎΠ½Π½ΡΠΌ SQL).
|
| 69 |
+
|
| 70 |
+
## ΠΠΈΠΏΠ΅ΡΠΏΠ°ΡΠ°ΠΌΠ΅ΡΡΡ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ
|
| 71 |
+
|
| 72 |
+
| ΠΠ°ΡΠ°ΠΌΠ΅ΡΡ | ΠΠ½Π°ΡΠ΅Π½ΠΈΠ΅ |
|
| 73 |
+
|---|---|
|
| 74 |
+
| ΠΠ°Π·ΠΎΠ²Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ | Qwen/Qwen2.5-Coder-3B-Instruct |
|
| 75 |
+
| ΠΠ΅ΡΠΎΠ΄ | QLoRA (NF4, double quantization, compute dtype = bfloat16) |
|
| 76 |
+
| LoRA rank `r` | 16 |
|
| 77 |
+
| LoRA alpha | 32 |
|
| 78 |
+
| LoRA dropout | 0.05 |
|
| 79 |
+
| Target modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
|
| 80 |
+
| ΠΠΏΠΎΡ
ΠΈ | 2 |
|
| 81 |
+
| Π Π°Π·ΠΌΠ΅Ρ Π±Π°ΡΡΠ° | 1 (Π½Π° ΡΡΡΡΠΎΠΉΡΡΠ²ΠΎ) |
|
| 82 |
+
| ΠΠ°ΠΊΠΎΠΏΠ»Π΅Π½ΠΈΠ΅ Π³ΡΠ°Π΄ΠΈΠ΅Π½ΡΠΎΠ² | 8 ΡΠ°Π³ΠΎΠ² (ΡΡΡΠ΅ΠΊΡΠΈΠ²Π½ΡΠΉ Π±Π°ΡΡ 8) |
|
| 83 |
+
| Π‘ΠΊΠΎΡΠΎΡΡΡ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ | 2e-4 (cosine, warmup 3 %) |
|
| 84 |
+
| Max sequence length | 1024 |
|
| 85 |
+
| Π’ΠΎΡΠ½ΠΎΡΡΡ | bfloat16 |
|
| 86 |
+
| ΠΠ±ΠΎΡΡΠ΄ΠΎΠ²Π°Π½ΠΈΠ΅ | NVIDIA Tesla T4 16 GB (Kaggle) |
|
| 87 |
+
|
| 88 |
+
## ΠΠ΅ΡΡΠΈΠΊΠΈ
|
| 89 |
+
|
| 90 |
+
ΠΡΠ΅Π½ΠΊΠ° Π½Π° Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΠΎΠ½Π½ΠΎΠΉ Π²ΡΠ±ΠΎΡΠΊΠ΅ PAUQ (1076 ΠΏΡΠΈΠΌΠ΅ΡΠΎΠ², 166 Π±Π°Π· Π΄Π°Π½Π½ΡΡ
):
|
| 91 |
+
|
| 92 |
+
| ΠΠ΅ΡΡΠΈΠΊΠ° | ΠΠ½Π°ΡΠ΅Π½ΠΈΠ΅ |
|
| 93 |
+
|---|---|
|
| 94 |
+
| Exact Match (EM) | 40.0 % (430 / 1076) |
|
| 95 |
+
| Execution Accuracy (EX) | 71.9 % (772 / 1074) |
|
| 96 |
+
|
| 97 |
+
ΠΠ»Ρ ΡΡΠ°Π²Π½Π΅Π½ΠΈΡ, ΠΎΠΏΡΠ±Π»ΠΈΠΊΠΎΠ²Π°Π½Π½ΡΠ΅ ΡΠ΅Π·ΡΠ»ΡΡΠ°ΡΡ Π½Π° ΡΠΎΠΉ ΠΆΠ΅ Π²ΡΠ±ΠΎΡΠΊΠ΅:
|
| 98 |
+
|
| 99 |
+
| ΠΠΎΠ΄Π΅Π»Ρ | EM | EX |
|
| 100 |
+
|---|---|---|
|
| 101 |
+
| RAT-SQL (PAUQ, mono) | 51 % | 49 % |
|
| 102 |
+
| BRIDGE (PAUQ, mono) | 52 % | 48 % |
|
| 103 |
+
| **Qwen2.5-Coder-3B + QLoRA (ΡΡΠ° ΡΠ°Π±ΠΎΡΠ°)** | **40 %** | **71.9 %** |
|
| 104 |
+
|
| 105 |
+
ΠΡΡΠΎΠΊΠΈΠΉ ΡΠ°Π·ΡΡΠ² ΠΌΠ΅ΠΆΠ΄Ρ EM ΠΈ EX ΠΎΡΡΠ°ΠΆΠ°Π΅Ρ ΡΠΏΠ΅ΡΠΈΡΠΈΠΊΡ Π·Π°Π΄Π°ΡΠΈ: ΠΌΠΎΠ΄Π΅Π»Ρ
|
| 106 |
+
Π³Π΅Π½Π΅ΡΠΈΡΡΠ΅Ρ ΡΠ΅ΠΌΠ°Π½ΡΠΈΡΠ΅ΡΠΊΠΈ ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠ΅ Π·Π°ΠΏΡΠΎΡΡ, Π½ΠΎ ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠΈ
|
| 107 |
+
ΠΎΡΠ»ΠΈΡΠ°ΡΡΠΈΠ΅ΡΡ ΠΎΡ ΡΡΠ°Π»ΠΎΠ½Π½ΡΡ
(Π΄ΡΡΠ³ΠΎΠΉ ΠΏΠΎΡΡΠ΄ΠΎΠΊ ΡΡΠ»ΠΎΠ²ΠΈΠΉ, Π°Π»ΡΡΠ΅ΡΠ½Π°ΡΠΈΠ²Π½ΡΠ΅
|
| 108 |
+
JOIN-ΡΡΡΠ°ΡΠ΅Π³ΠΈΠΈ, ΠΈΠ½ΡΠ΅ ΠΏΡΠ΅Π²Π΄ΠΎΠ½ΠΈΠΌΡ).
|
| 109 |
+
|
| 110 |
+
## ΠΠ³ΡΠ°Π½ΠΈΡΠ΅Π½ΠΈΡ
|
| 111 |
+
|
| 112 |
+
- ΠΠΎΠ΄Π΅Π»Ρ ΠΎΠ±ΡΡΠ΅Π½Π° Π½Π° ΡΡ
Π΅ΠΌΠ°Ρ
PAUQ/Spider. ΠΠ° ΡΡΡΠ΅ΡΡΠ²Π΅Π½Π½ΠΎ ΠΎΡΠ»ΠΈΡΠ°ΡΡΠΈΡ
ΡΡ
|
| 113 |
+
ΠΏΡΠ΅Π΄ΠΌΠ΅ΡΠ½ΡΡ
ΠΎΠ±Π»Π°ΡΡΡΡ
ΠΊΠ°ΡΠ΅ΡΡΠ²ΠΎ ΠΌΠΎΠΆΠ΅Ρ ΠΏΠ°Π΄Π°ΡΡ; Π΄Π»Ρ Π°Π΄Π°ΠΏΡΠ°ΡΠΈΠΈ
|
| 114 |
+
ΠΏΡΠ΅Π΄ΡΡΠΌΠΎΡΡΠ΅Π½ ΠΌΠ΅Ρ
Π°Π½ΠΈΠ·ΠΌ Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ Π² ΠΎΡΠ½ΠΎΠ²Π½ΠΎΠΌ ΠΏΡΠΎΠ΅ΠΊΡΠ΅.
|
| 115 |
+
- ΠΠ°ΠΈΠ±ΠΎΠ»Π΅Π΅ ΡΠ»ΠΎΠΆΠ½ΡΠ΅ ΠΊΠ»Π°ΡΡΡ Π·Π°ΠΏΡΠΎΡΠΎΠ² β Π²Π»ΠΎΠΆΠ΅Π½Π½ΡΠ΅ SELECT, ΠΊΠΎΠ½ΡΡΡΡΠΊΡΠΈΠΈ
|
| 116 |
+
INTERSECT/EXCEPT/UNION, HAVING Ρ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΈΠΌΠΈ ΡΡΠ»ΠΎΠ²ΠΈΡΠΌΠΈ β ΠΎΡΡΠ°ΡΡΡΡ
|
| 117 |
+
ΡΠ»Π°Π±ΡΠΌ ΠΌΠ΅ΡΡΠΎΠΌ.
|
| 118 |
+
- ΠΡΠΈ ΡΠ°Π±ΠΎΡΠ΅ Π½Π° CPU Π±Π΅Π· GPU ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ Π·Π°Π½ΠΈΠΌΠ°Π΅Ρ 15β30 ΡΠ΅ΠΊΡΠ½Π΄ Π½Π° Π·Π°ΠΏΡΠΎΡ.
|
| 119 |
+
- ΠΠ΄Π°ΠΏΡΠ΅Ρ ΠΎΡΠΈΠ΅Π½ΡΠΈΡΠΎΠ²Π°Π½ Π½Π° SQLite. ΠΠ»Ρ PostgreSQL/MySQL ΡΠ΅ΠΊΠΎΠΌΠ΅Π½Π΄ΡΠ΅ΡΡΡ
|
| 120 |
+
ΡΠ²Π½ΠΎ ΡΠΊΠ°Π·Π°ΡΡ Π΄ΠΈΠ°Π»Π΅ΠΊΡ Π² ΠΏΡΠΎΠΌΠΏΡΠ΅.
|
| 121 |
+
|
| 122 |
+
## ΠΠΈΡΠ΅Π½Π·ΠΈΡ
|
| 123 |
+
|
| 124 |
+
Apache 2.0. ΠΠ°Π·ΠΎΠ²Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ Qwen2.5-Coder ΡΠ°ΠΊΠΆΠ΅ ΡΠ°ΡΠΏΡΠΎΡΡΡΠ°Π½ΡΠ΅ΡΡΡ ΠΏΠΎΠ΄
|
| 125 |
+
Π»ΠΈΡΠ΅Π½Π·ΠΈΠ΅ΠΉ Apache 2.0; Π΄Π°ΡΠ°ΡΠ΅Ρ PAUQ β Apache 2.0.
|
| 126 |
+
|
| 127 |
+
## Π¦ΠΈΡΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅
|
| 128 |
+
|
| 129 |
+
```bibtex
|
| 130 |
+
@misc{ru2sql-qlora-2026,
|
| 131 |
+
title = {Ru2SQL: Russian Text-to-SQL via QLoRA on Qwen2.5-Coder},
|
| 132 |
+
author = {Siryazeev, Danis},
|
| 133 |
+
year = {2026},
|
| 134 |
+
note = {Bachelor's thesis, Kazan National Research Technical University},
|
| 135 |
+
}
|
| 136 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/demo/sales.sqlite-journal
DELETED
|
Binary file (512 Bytes)
|
|
|
evaluate_pauq.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Π‘ΠΊΡΠΈΠΏΡ ΠΏΡΠΎΠ³ΠΎΠ½Π° ΠΌΠ΅ΡΡΠΈΠΊ Ru2SQL Π½Π° Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΠΎΠ½Π½ΠΎΠΉ Π²ΡΠ±ΠΎΡΠΊΠ΅ PAUQ.
|
| 3 |
+
|
| 4 |
+
Π‘ΡΠΈΡΠ°Π΅Ρ Exact Match (EM) ΠΈ Execution Accuracy (EX) Π΄Π»Ρ Π΄ΠΎΠΎΠ±ΡΡΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 5 |
+
Qwen2.5-Coder-3B + QLoRA. ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ Π½Π° Kaggle (T4 GPU), ΡΠ΅Π·ΡΠ»ΡΡΠ°Ρ
|
| 6 |
+
ΡΠΎΡ
ΡΠ°Π½ΡΠ΅ΡΡΡ Π² eval_results.csv ΠΈ eval_summary.json.
|
| 7 |
+
|
| 8 |
+
ΠΡΡΠΎΡΠ½ΠΈΠΊ ΡΠΈΡΡ Π΄Π»Ρ ΡΠ°Π·Π΄Π΅Π»Π° 4.3 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ.
|
| 9 |
+
|
| 10 |
+
ΠΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π½ΠΈΠ΅ Π½Π° Kaggle:
|
| 11 |
+
1. ΠΠ°Π³ΡΡΠ·ΠΈΡΡ ΠΏΡΠΎΠ΅ΠΊΡ ΡΠ΅Π»ΠΈΠΊΠΎΠΌ (ΡΠ΅ΡΠ΅Π· Kaggle Dataset ΠΈΠ»ΠΈ git clone).
|
| 12 |
+
2. Π£ΡΡΠ°Π½ΠΎΠ²ΠΈΡΡ ADAPTER_ID Π½Π° ΡΠ²ΠΎΠΉ HF-ΡΠ΅ΠΏΠΎ.
|
| 13 |
+
3. python evaluate_pauq.py
|
| 14 |
+
|
| 15 |
+
ΠΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ° ΠΈ Π½ΠΎΡΠΌΠ°Π»ΠΈΠ·Π°ΡΠΈΡ SQL ΠΈΠΌΠΏΠΎΡΡΠΈΡΡΡΡΡΡ ΠΈΠ· src/models/postprocess.py,
|
| 16 |
+
ΡΡΠΎΠ±Ρ ΠΌΠ΅ΡΡΠΈΠΊΠΈ Π»ΠΎΠΊΠ°Π»ΡΠ½ΠΎ ΠΈ Π½Π° Kaggle ΠΎΡΡΠ°Π²Π°Π»ΠΈΡΡ ΡΠΎΠΏΠΎΡΡΠ°Π²ΠΈΠΌΡΠΌΠΈ.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
# ΠΠ΅Π»Π°Π΅ΠΌ ΠΏΠ°ΠΊΠ΅Ρ src/ ΠΈΠΌΠΏΠΎΡΡΠΈΡΡΠ΅ΠΌΡΠΌ, ΠΊΠΎΠ³Π΄Π° ΡΠΊΡΠΈΠΏΡ Π·Π°ΠΏΡΡΠΊΠ°Π΅ΡΡΡ ΠΈΠ· ΠΊΠΎΡΠ½Ρ.
|
| 23 |
+
_PROJECT_ROOT = Path(__file__).resolve().parent
|
| 24 |
+
if str(_PROJECT_ROOT) not in sys.path:
|
| 25 |
+
sys.path.insert(0, str(_PROJECT_ROOT))
|
| 26 |
+
|
| 27 |
+
import csv
|
| 28 |
+
import json
|
| 29 |
+
import sqlite3
|
| 30 |
+
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 35 |
+
from peft import PeftModel
|
| 36 |
+
from datasets import load_dataset
|
| 37 |
+
|
| 38 |
+
from src.models.postprocess import (
|
| 39 |
+
normalize_sql,
|
| 40 |
+
strip_model_artifacts as strip_artifacts,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# βββ CONFIG βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
|
| 45 |
+
BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 46 |
+
ADAPTER_ID = "Tyycha/qwen-coder-pauq-lora"
|
| 47 |
+
PAUQ_SPLIT = "validation" # 1034 ΠΏΡΠΈΠΌΠ΅ΡΠ°
|
| 48 |
+
MAX_NEW_TOKENS = 256
|
| 49 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
+
|
| 51 |
+
# ΠΡΡΡ ΠΊ ΠΏΠ°ΠΏΠΊΠ΅ Ρ SQLite-Π±Π°Π·Π°ΠΌΠΈ PAUQ (Spider databases)
|
| 52 |
+
# ΠΠ° Kaggle: ΡΠΊΠ°ΡΠ°ΠΉ https://drive.google.com/uc?id=1iRDVHLr4mX2wQKSgA9VxUUFpj-3-Kj5B
|
| 53 |
+
# ΠΠ»ΠΈ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠΉ Π΄Π°ΡΠ°ΡΠ΅Ρ: kaggle datasets download -d wikisql/spider
|
| 54 |
+
PAUQ_DB_DIR = Path("./pauq_databases") # ΠΏΠ°ΠΏΠΊΠ°, Π³Π΄Π΅ Π»Π΅ΠΆΠ°Ρ ΠΏΠ°ΠΏΠΊΠΈ Π±Π°Π· Π΄Π°Π½Π½ΡΡ
|
| 55 |
+
|
| 56 |
+
# βββ OUTPUT βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
|
| 58 |
+
RESULTS_FILE = "eval_results.csv"
|
| 59 |
+
SUMMARY_FILE = "eval_summary.json"
|
| 60 |
+
|
| 61 |
+
# βββ LOAD MODEL βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
print("Loading model...")
|
| 64 |
+
bnb_config = BitsAndBytesConfig(
|
| 65 |
+
load_in_4bit=True,
|
| 66 |
+
bnb_4bit_quant_type="nf4",
|
| 67 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 68 |
+
bnb_4bit_use_double_quant=True,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
|
| 72 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 73 |
+
BASE_MODEL_ID,
|
| 74 |
+
quantization_config=bnb_config,
|
| 75 |
+
device_map="auto",
|
| 76 |
+
trust_remote_code=True
|
| 77 |
+
)
|
| 78 |
+
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
|
| 79 |
+
model.eval()
|
| 80 |
+
print(f"Model loaded on {DEVICE}")
|
| 81 |
+
|
| 82 |
+
# βββ SCHEMA RETRIEVER βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
+
|
| 84 |
+
def get_schema(db_id: str) -> str:
|
| 85 |
+
"""Extract CREATE TABLE statements from SQLite database."""
|
| 86 |
+
db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
|
| 87 |
+
if not db_path.exists():
|
| 88 |
+
return f"-- Database {db_id} not found"
|
| 89 |
+
|
| 90 |
+
conn = sqlite3.connect(str(db_path))
|
| 91 |
+
cursor = conn.cursor()
|
| 92 |
+
|
| 93 |
+
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
| 94 |
+
tables = [row[0] for row in cursor.fetchall()]
|
| 95 |
+
|
| 96 |
+
schema_parts = []
|
| 97 |
+
for table in tables:
|
| 98 |
+
cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table,))
|
| 99 |
+
create_sql = cursor.fetchone()
|
| 100 |
+
if create_sql and create_sql[0]:
|
| 101 |
+
schema_parts.append(create_sql[0])
|
| 102 |
+
# Add sample rows
|
| 103 |
+
try:
|
| 104 |
+
cursor.execute(f"SELECT * FROM \"{table}\" LIMIT 3")
|
| 105 |
+
rows = cursor.fetchall()
|
| 106 |
+
cursor.execute(f"PRAGMA table_info(\"{table}\")")
|
| 107 |
+
cols = [col[1] for col in cursor.fetchall()]
|
| 108 |
+
if rows:
|
| 109 |
+
schema_parts.append(f"-- Sample data for {table}:")
|
| 110 |
+
schema_parts.append("-- " + " | ".join(cols))
|
| 111 |
+
for row in rows[:3]:
|
| 112 |
+
schema_parts.append("-- " + " | ".join(str(v) for v in row))
|
| 113 |
+
except Exception:
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
conn.close()
|
| 117 |
+
return "\n\n".join(schema_parts)
|
| 118 |
+
|
| 119 |
+
# βββ PROMPT BUILDER βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 120 |
+
|
| 121 |
+
SYSTEM_PROMPT = (
|
| 122 |
+
"Π’Ρ β Π°ΡΡΠΈΡΡΠ΅Π½Ρ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΡΠ΅Ρ Π²ΠΎΠΏΡΠΎΡΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ. "
|
| 123 |
+
"ΠΡΠ²Π΅ΡΠ°ΠΉ Π’ΠΠΠ¬ΠΠ SQL-Π·Π°ΠΏΡΠΎΡΠΎΠΌ Π±Π΅Π· ΠΎΠ±ΡΡΡΠ½Π΅Π½ΠΈΠΉ, ΠΊΠΎΠΌΠΌΠ΅Π½ΡΠ°ΡΠΈΠ΅Π² ΠΈ markdown-ΡΠ°Π·ΠΌΠ΅ΡΠΊΠΈ."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def build_prompt(question: str, schema: str) -> str:
|
| 127 |
+
user_msg = f"Schema:\n{schema}\n\nQuestion: {question}\n\nSQL:"
|
| 128 |
+
messages = [
|
| 129 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 130 |
+
{"role": "user", "content": user_msg},
|
| 131 |
+
]
|
| 132 |
+
return tokenizer.apply_chat_template(
|
| 133 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# βββ SQL POSTPROCESSING βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 137 |
+
# strip_artifacts ΠΈ normalize_sql ΠΈΠΌΠΏΠΎΡΡΠΈΡΡΡΡΡΡ ΠΈΠ· src/models/postprocess.py
|
| 138 |
+
# Π΄Π»Ρ Π³Π°ΡΠ°Π½ΡΠΈΠΈ ΡΠΎΠ³ΠΎ, ΡΡΠΎ ΠΌΠ΅ΡΡΠΈΠΊΠΈ Π½Π° Kaggle ΠΈ Π² Π»ΠΎΠΊΠ°Π»ΡΠ½ΡΡ
ΡΠ΅ΡΡΠ°Ρ
ΡΡΠΈΡΠ°ΡΡΡΡ
|
| 139 |
+
# ΠΏΠΎ ΠΎΠ΄Π½ΠΎΠΉ ΠΈ ΡΠΎΠΉ ΠΆΠ΅ Π»ΠΎΠ³ΠΈΠΊΠ΅.
|
| 140 |
+
|
| 141 |
+
# βββ EXECUTION ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
|
| 143 |
+
def execute_sql(sql: str, db_id: str):
|
| 144 |
+
"""Execute SQL and return result set as frozenset of tuples."""
|
| 145 |
+
db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
|
| 146 |
+
if not db_path.exists():
|
| 147 |
+
return None
|
| 148 |
+
try:
|
| 149 |
+
uri = f"file:{db_path}?mode=ro"
|
| 150 |
+
conn = sqlite3.connect(uri, uri=True)
|
| 151 |
+
cursor = conn.cursor()
|
| 152 |
+
cursor.execute(sql)
|
| 153 |
+
rows = cursor.fetchall()
|
| 154 |
+
conn.close()
|
| 155 |
+
return frozenset(tuple(str(v) for v in row) for row in rows)
|
| 156 |
+
except Exception:
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
# βββ INFERENCE ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 160 |
+
|
| 161 |
+
def generate_sql(question: str, schema: str) -> str:
|
| 162 |
+
prompt = build_prompt(question, schema)
|
| 163 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
|
| 164 |
+
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
outputs = model.generate(
|
| 167 |
+
**inputs,
|
| 168 |
+
max_new_tokens=MAX_NEW_TOKENS,
|
| 169 |
+
do_sample=False,
|
| 170 |
+
temperature=1.0,
|
| 171 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 175 |
+
return strip_artifacts(generated)
|
| 176 |
+
|
| 177 |
+
# βββ MAIN EVALUATION LOOP βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 178 |
+
|
| 179 |
+
def main():
|
| 180 |
+
print(f"Loading PAUQ {PAUQ_SPLIT} split...")
|
| 181 |
+
dataset = load_dataset("ai-forever/PAUQ", split=PAUQ_SPLIT)
|
| 182 |
+
print(f"Loaded {len(dataset)} examples")
|
| 183 |
+
|
| 184 |
+
results = []
|
| 185 |
+
em_correct = 0
|
| 186 |
+
ex_correct = 0
|
| 187 |
+
ex_total = 0 # only count where execution was possible
|
| 188 |
+
|
| 189 |
+
with open(RESULTS_FILE, "w", newline="", encoding="utf-8") as f:
|
| 190 |
+
writer = csv.DictWriter(f, fieldnames=[
|
| 191 |
+
"idx", "db_id", "question", "gold_sql", "pred_sql",
|
| 192 |
+
"em", "ex", "error"
|
| 193 |
+
])
|
| 194 |
+
writer.writeheader()
|
| 195 |
+
|
| 196 |
+
for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
|
| 197 |
+
question = example.get("question_ru") or example.get("question", "")
|
| 198 |
+
gold_sql = example.get("query", "")
|
| 199 |
+
db_id = example.get("db_id", "")
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
schema = get_schema(db_id)
|
| 203 |
+
pred_sql = generate_sql(question, schema)
|
| 204 |
+
|
| 205 |
+
# Exact Match
|
| 206 |
+
norm_pred = normalize_sql(pred_sql)
|
| 207 |
+
norm_gold = normalize_sql(gold_sql)
|
| 208 |
+
em = int(norm_pred == norm_gold)
|
| 209 |
+
|
| 210 |
+
# Execution Accuracy
|
| 211 |
+
pred_result = execute_sql(pred_sql, db_id)
|
| 212 |
+
gold_result = execute_sql(gold_sql, db_id)
|
| 213 |
+
|
| 214 |
+
if gold_result is not None:
|
| 215 |
+
ex = int(pred_result == gold_result)
|
| 216 |
+
ex_correct += ex
|
| 217 |
+
ex_total += 1
|
| 218 |
+
else:
|
| 219 |
+
ex = None
|
| 220 |
+
|
| 221 |
+
em_correct += em
|
| 222 |
+
error = ""
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
pred_sql = ""
|
| 226 |
+
em = 0
|
| 227 |
+
ex = None
|
| 228 |
+
error = str(e)[:200]
|
| 229 |
+
|
| 230 |
+
row = {
|
| 231 |
+
"idx": idx, "db_id": db_id, "question": question,
|
| 232 |
+
"gold_sql": gold_sql, "pred_sql": pred_sql,
|
| 233 |
+
"em": em, "ex": ex, "error": error
|
| 234 |
+
}
|
| 235 |
+
writer.writerow(row)
|
| 236 |
+
results.append(row)
|
| 237 |
+
|
| 238 |
+
# Progress every 100 examples
|
| 239 |
+
if (idx + 1) % 100 == 0:
|
| 240 |
+
cur_em = em_correct / (idx + 1)
|
| 241 |
+
cur_ex = ex_correct / max(ex_total, 1)
|
| 242 |
+
print(f"[{idx+1}/{len(dataset)}] EM={cur_em:.3f} EX={cur_ex:.3f}")
|
| 243 |
+
|
| 244 |
+
# Final summary
|
| 245 |
+
n = len(dataset)
|
| 246 |
+
final_em = em_correct / n
|
| 247 |
+
final_ex = ex_correct / max(ex_total, 1)
|
| 248 |
+
|
| 249 |
+
summary = {
|
| 250 |
+
"model": ADAPTER_ID,
|
| 251 |
+
"split": PAUQ_SPLIT,
|
| 252 |
+
"n_examples": n,
|
| 253 |
+
"exact_match": round(final_em, 4),
|
| 254 |
+
"execution_accuracy": round(final_ex, 4),
|
| 255 |
+
"em_correct": em_correct,
|
| 256 |
+
"ex_correct": ex_correct,
|
| 257 |
+
"ex_total": ex_total
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
|
| 261 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 262 |
+
|
| 263 |
+
print("\n" + "="*50)
|
| 264 |
+
print(f"RESULTS on PAUQ {PAUQ_SPLIT} ({n} examples)")
|
| 265 |
+
print(f" Exact Match (EM): {final_em:.1%} ({em_correct}/{n})")
|
| 266 |
+
print(f" Execution Accuracy (EX): {final_ex:.1%} ({ex_correct}/{ex_total})")
|
| 267 |
+
print(f"\nDetailed results saved to: {RESULTS_FILE}")
|
| 268 |
+
print(f"Summary saved to: {SUMMARY_FILE}")
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
main()
|
scripts/run_app.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ΠΠ°ΡΠ½ΡΠ΅Ρ Π΄Π²ΡΡ
ΠΏΡΠΎΡΠ΅ΡΡΠΎΠ² Ru2SQL: FastAPI + Streamlit.
|
| 2 |
+
|
| 3 |
+
ΠΠ°ΠΏΡΡΠΊ:
|
| 4 |
+
python scripts/run_app.py
|
| 5 |
+
|
| 6 |
+
Π‘ΠΊΡΠΈΠΏΡ ΡΡΠ°ΡΡΡΠ΅Ρ uvicorn Π² ΡΠΎΠ½Π΅, ΠΆΠ΄ΡΡ ΠΏΠΎΠΊΠ° /health Π½Π°ΡΠ½ΡΡ ΠΎΡΠ²Π΅ΡΠ°ΡΡ
|
| 7 |
+
(ΠΌΠΎΠ΄Π΅Π»Ρ Qwen ΠΌΠΎΠΆΠ΅Ρ ΠΊΠ°ΡΠ°ΡΡΡΡ 5β10 ΠΌΠΈΠ½ΡΡ ΠΏΡΠΈ ΠΏΠ΅ΡΠ²ΠΎΠΌ Π·Π°ΠΏΡΡΠΊΠ΅), ΠΈ
|
| 8 |
+
ΠΏΠΎΠ΄Π½ΠΈΠΌΠ°Π΅Ρ Streamlit ΠΊΠ°ΠΊ ΠΎΡΠ½ΠΎΠ²Π½ΠΎΠΉ ΠΏΡΠΎΡΠ΅ΡΡ. ΠΠΎ Ctrl+C ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΠΎ
|
| 9 |
+
ΠΎΡΡΠ°Π½Π°Π²Π»ΠΈΠ²Π°Π΅Ρ ΠΎΠ±Π°.
|
| 10 |
+
|
| 11 |
+
ΠΠΎΠ»Π΅Π·Π½ΠΎ ΠΏΡΠΈ ΡΠ°Π·ΡΠ°Π±ΠΎΡΠΊΠ΅ ΠΈ Π΄Π»Ρ Π΄Π΅ΠΌΠΎ Π½Π° Π·Π°ΡΠΈΡΠ΅: ΠΎΠ΄ΠΈΠ½ Ctrl+C β ΠΈ ΠΎΠ±Π°
|
| 12 |
+
ΠΏΡΠΎΡΠ΅ΡΡΠ° Π·Π°Π²Π΅ΡΡΠ΅Π½Ρ, Π½Π΅ Π½ΡΠΆΠ½ΠΎ ΠΈΡΠΊΠ°ΡΡ Π²ΠΈΡΡΡΠΈΠΉ uvicorn Π² taskmgr.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import signal
|
| 19 |
+
import subprocess
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import httpx
|
| 25 |
+
|
| 26 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 27 |
+
API_HOST = os.environ.get("API_HOST", "127.0.0.1")
|
| 28 |
+
API_PORT = int(os.environ.get("API_PORT", "8000"))
|
| 29 |
+
STREAMLIT_PORT = int(os.environ.get("STREAMLIT_PORT", "8501"))
|
| 30 |
+
# Streamlit ΡΠ»ΡΡΠ°Π΅Ρ 0.0.0.0 Π² ΠΊΠΎΠ½ΡΠ΅ΠΉΠ½Π΅ΡΠ΅ (Π½ΡΠΆΠ΅Π½ Π²Π½Π΅ΡΠ½ΠΈΠΉ Π΄ΠΎΡΡΡΠΏ),
|
| 31 |
+
# Π° Π»ΠΎΠΊΠ°Π»ΡΠ½ΠΎ ΠΏΠΎ ΡΠΌΠΎΠ»ΡΠ°Π½ΠΈΡ 127.0.0.1 β ΡΡΠΎ ΡΠΏΡΠ°Π²Π»ΡΠ΅ΡΡΡ ΠΎΡΠ΄Π΅Π»ΡΠ½ΠΎΠΉ ΠΏΠ΅ΡΠ΅ΠΌΠ΅Π½Π½ΠΎΠΉ.
|
| 32 |
+
STREAMLIT_HOST = os.environ.get("STREAMLIT_HOST", "127.0.0.1")
|
| 33 |
+
API_URL = f"http://{API_HOST}:{API_PORT}"
|
| 34 |
+
WAIT_API_SECONDS = 600 # ΠΌΠΎΠ΄Π΅Π»Ρ ΡΠΊΠ°ΡΠΈΠ²Π°Π΅ΡΡΡ Π΄ΠΎΠ»Π³ΠΎ
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def wait_for_api(timeout: int) -> bool:
|
| 38 |
+
"""ΠΠΏΡΠ°ΡΠΈΠ²Π°Π΅ΠΌ /health, ΠΏΠΎΠΊΠ° Π½Π΅ ΠΏΠΎΠ»ΡΡΠΈΠΌ 200 ΠΈΠ»ΠΈ Π½Π΅ Π²ΡΠΉΠ΄Π΅Ρ timeout."""
|
| 39 |
+
deadline = time.time() + timeout
|
| 40 |
+
while time.time() < deadline:
|
| 41 |
+
try:
|
| 42 |
+
r = httpx.get(f"{API_URL}/health", timeout=2.0)
|
| 43 |
+
if r.status_code == 200:
|
| 44 |
+
return True
|
| 45 |
+
except Exception:
|
| 46 |
+
pass
|
| 47 |
+
time.sleep(1.5)
|
| 48 |
+
return False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main():
|
| 52 |
+
print(f"[run_app] ΠΊΠΎΡΠ΅Π½Ρ ΠΏΡΠΎΠ΅ΠΊΡΠ°: {ROOT}")
|
| 53 |
+
print(f"[run_app] ΡΡΠ°ΡΡΡΡ uvicorn Π½Π° {API_URL}")
|
| 54 |
+
|
| 55 |
+
creation_flags = 0
|
| 56 |
+
if sys.platform == "win32":
|
| 57 |
+
creation_flags = subprocess.CREATE_NEW_PROCESS_GROUP
|
| 58 |
+
|
| 59 |
+
uvicorn_proc = subprocess.Popen(
|
| 60 |
+
[
|
| 61 |
+
sys.executable, "-m", "uvicorn", "src.api.main:app",
|
| 62 |
+
"--host", API_HOST,
|
| 63 |
+
"--port", str(API_PORT),
|
| 64 |
+
],
|
| 65 |
+
cwd=str(ROOT),
|
| 66 |
+
creationflags=creation_flags,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
print(f"[run_app] ΠΆΠ΄Ρ Π³ΠΎΡΠΎΠ²Π½ΠΎΡΡΠΈ API ({WAIT_API_SECONDS} ΡΠ΅ΠΊ ΠΌΠ°ΠΊΡ) β "
|
| 71 |
+
"ΠΏΡΠΈ ΠΏΠ΅ΡΠ²ΠΎΠΌ Π·Π°ΠΏΡΡΠΊΠ΅ Qwen ΠΊΠ°ΡΠ°Π΅ΡΡΡ Ρ HuggingFace")
|
| 72 |
+
if not wait_for_api(WAIT_API_SECONDS):
|
| 73 |
+
print("[run_app] API Π½Π΅ ΠΏΠΎΠ΄Π½ΡΠ»ΡΡ Π·Π° ΠΎΡΠ²Π΅Π΄ΡΠ½Π½ΠΎΠ΅ Π²ΡΠ΅ΠΌΡ. ΠΡΠΎΠ²Π΅ΡΡ Π»ΠΎΠ³ΠΈ uvicorn.")
|
| 74 |
+
uvicorn_proc.terminate()
|
| 75 |
+
sys.exit(1)
|
| 76 |
+
print(f"[run_app] API Π³ΠΎΡΠΎΠ², ΡΡΠ°ΡΡΡΡ Streamlit Π½Π° http://{STREAMLIT_HOST}:{STREAMLIT_PORT}")
|
| 77 |
+
|
| 78 |
+
env = os.environ.copy()
|
| 79 |
+
env["RU2SQL_API_URL"] = API_URL
|
| 80 |
+
streamlit_cmd = [
|
| 81 |
+
sys.executable, "-m", "streamlit", "run", "streamlit_app.py",
|
| 82 |
+
"--server.port", str(STREAMLIT_PORT),
|
| 83 |
+
"--server.address", STREAMLIT_HOST,
|
| 84 |
+
]
|
| 85 |
+
streamlit_proc = subprocess.Popen(streamlit_cmd, cwd=str(ROOT), env=env)
|
| 86 |
+
|
| 87 |
+
# ΠΠ»Π°Π²Π½ΡΠΉ ΡΠΈΠΊΠ» β ΠΆΠ΄ΡΠΌ, ΠΏΠΎΠΊΠ° streamlit ΠΈΠ»ΠΈ uvicorn Π½Π΅ ΡΠΏΠ°Π΄ΡΡ
|
| 88 |
+
try:
|
| 89 |
+
while True:
|
| 90 |
+
if uvicorn_proc.poll() is not None:
|
| 91 |
+
print("[run_app] uvicorn Π·Π°Π²Π΅ΡΡΠΈΠ»ΡΡ, ΠΎΡΡΠ°Π½Π°Π²Π»ΠΈΠ²Π°Ρ streamlit")
|
| 92 |
+
streamlit_proc.terminate()
|
| 93 |
+
break
|
| 94 |
+
if streamlit_proc.poll() is not None:
|
| 95 |
+
print("[run_app] streamlit Π·Π°Π²Π΅ΡΡΠΈΠ»ΡΡ, ΠΎΡΡΠ°Π½Π°Π²Π»ΠΈΠ²Π°Ρ uvicorn")
|
| 96 |
+
break
|
| 97 |
+
time.sleep(1.0)
|
| 98 |
+
except KeyboardInterrupt:
|
| 99 |
+
print("\n[run_app] Ctrl+C ΠΏΠΎΠ»ΡΡΠ΅Π½, Π·Π°Π²Π΅ΡΡΠ°Ρ ΠΏΡΠΎΡΠ΅ΡΡΡβ¦")
|
| 100 |
+
streamlit_proc.terminate()
|
| 101 |
+
finally:
|
| 102 |
+
try:
|
| 103 |
+
uvicorn_proc.terminate()
|
| 104 |
+
uvicorn_proc.wait(timeout=5)
|
| 105 |
+
except subprocess.TimeoutExpired:
|
| 106 |
+
uvicorn_proc.kill()
|
| 107 |
+
except Exception:
|
| 108 |
+
pass
|
| 109 |
+
print("[run_app] ΠΎΠ±Π° ΠΏΡΠΎΡΠ΅ΡΡΠ° ΠΎΡΡΠ°Π½ΠΎΠ²Π»Π΅Π½Ρ")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
main()
|
scripts/smoke_local.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ΠΠΎΠΊΠ°Π»ΡΠ½ΡΠΉ smoke-ΡΠ΅ΡΡ ΡΠ°Π±ΠΎΡΠΎΡΠΏΠΎΡΠΎΠ±Π½ΠΎΡΡΠΈ Ru2SQL.
|
| 2 |
+
|
| 3 |
+
ΠΡΠΎΠ³ΠΎΠ½ΡΠ΅Ρ ΠΊΠ»ΡΡΠ΅Π²ΡΠ΅ ΡΠ»ΠΎΠΈ Π±Π΅Π· ΠΏΠΎΠ΄Π½ΡΡΠΈΡ Streamlit/FastAPI ΠΊΠ°ΠΊ ΠΎΡΠ΄Π΅Π»ΡΠ½ΡΡ
|
| 4 |
+
ΠΏΡΠΎΡΠ΅ΡΡΠΎΠ². ΠΠΎΠΊΡΡΠ²Π°Π΅Ρ: ΠΈΠΌΠΏΠΎΡΡΡ, demo-Π±Π°Π·Ρ, vocabulary, prompt,
|
| 5 |
+
ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΡ, guardrail, Π½ΠΎΠ²ΡΠΉ API ΡΠ΅ΡΠ΅Π· FastAPI TestClient ΠΈ
|
| 6 |
+
ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎ β ΠΊΠΎΡΠΎΡΠΊΠΈΠΉ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
|
| 7 |
+
|
| 8 |
+
ΠΠ°ΠΏΡΡΠΊ:
|
| 9 |
+
python scripts/smoke_local.py
|
| 10 |
+
python scripts/smoke_local.py --with-model # ΠΌΠ΅Π΄Π»Π΅Π½Π½ΠΎ, Π³ΡΡΠ·ΠΈΡ Qwen
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
import time
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 21 |
+
sys.path.insert(0, str(ROOT))
|
| 22 |
+
|
| 23 |
+
GREEN = "\033[32m"
|
| 24 |
+
RED = "\033[31m"
|
| 25 |
+
YELLOW = "\033[33m"
|
| 26 |
+
RESET = "\033[0m"
|
| 27 |
+
|
| 28 |
+
results: list[tuple[str, str, str]] = []
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check(name: str):
|
| 32 |
+
def decorator(fn):
|
| 33 |
+
try:
|
| 34 |
+
t0 = time.time()
|
| 35 |
+
fn()
|
| 36 |
+
dt = time.time() - t0
|
| 37 |
+
results.append(("OK", name, f"{dt*1000:.0f} ΠΌΡ"))
|
| 38 |
+
except Exception as e:
|
| 39 |
+
results.append(("FAIL", name, f"{type(e).__name__}: {e}"))
|
| 40 |
+
return fn
|
| 41 |
+
return decorator
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
# Π‘Π»ΠΎΠΉ 1 β ΠΈΠΌΠΏΠΎΡΡΡ, ΠΊΠΎΠ½ΡΠΈΠ³
|
| 46 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
@check("ΠΠΌΠΏΠΎΡΡΡ ΡΠ΄ΡΠ°")
|
| 48 |
+
def _():
|
| 49 |
+
from src.config import settings
|
| 50 |
+
from src.api.schemas import GenerateRequest, QueryRequest, SchemaRequest
|
| 51 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 52 |
+
from src.data.prompt import build_chat_messages, BASE_SYSTEM_PROMPT
|
| 53 |
+
from src.data.schema_provider import (
|
| 54 |
+
SpiderSchemaProvider, ConnectionSchemaProvider, TableSchema,
|
| 55 |
+
)
|
| 56 |
+
from src.db.connector import DbConnector
|
| 57 |
+
from src.db.executor import SqlExecutor
|
| 58 |
+
from src.models.postprocess import postprocess, is_select_only, is_valid_sql
|
| 59 |
+
assert settings.base_model_name.startswith("Qwen")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
+
# Π‘Π»ΠΎΠΉ 2 β Π΄Π΅ΠΌΠΎ-Π±Π°Π·Π° ΠΈ read-only guardrail
|
| 64 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
+
@check("ΠΠ΅ΠΌΠΎ-Π±Π°Π·Π° sales.sqlite: ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ ΠΈ ΡΡ
Π΅ΠΌΠ°")
|
| 66 |
+
def _():
|
| 67 |
+
from src.db.connector import DbConnector
|
| 68 |
+
db = ROOT / "data" / "demo" / "sales.sqlite"
|
| 69 |
+
assert db.exists(), f"ΠΠ΅ Π½Π°ΠΉΠ΄Π΅Π½ ΡΠ°ΠΉΠ» {db} β Π·Π°ΠΏΡΡΡΠΈ data/demo/create_demo_db.py"
|
| 70 |
+
conn = DbConnector(str(db))
|
| 71 |
+
tables = conn.list_tables()
|
| 72 |
+
assert set(tables) == {"customers", "managers", "products", "orders", "order_items"}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@check("ΠΠ΅ΠΌΠΎ-Π±Π°Π·Π°: ΡΠ΅Π°Π»ΡΠ½ΡΠΉ SELECT Π²ΡΡΡΡΠΊΠΈ")
|
| 76 |
+
def _():
|
| 77 |
+
from src.db.executor import SqlExecutor
|
| 78 |
+
ex = SqlExecutor(str(ROOT / "data" / "demo" / "sales.sqlite"))
|
| 79 |
+
res = ex.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
|
| 80 |
+
assert res.success, res.error
|
| 81 |
+
assert res.rows[0][0] > 0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@check("Read-only guardrail: DELETE ΠΎΡΠΊΠ»ΠΎΠ½ΡΠ΅ΡΡΡ Π΄ΡΠ°ΠΉΠ²Π΅ΡΠΎΠΌ")
|
| 85 |
+
def _():
|
| 86 |
+
from src.db.executor import SqlExecutor
|
| 87 |
+
ex = SqlExecutor(str(ROOT / "data" / "demo" / "sales.sqlite"))
|
| 88 |
+
res = ex.run("DELETE FROM orders")
|
| 89 |
+
assert not res.success
|
| 90 |
+
assert ex.run("SELECT COUNT(*) FROM orders").rows[0][0] == 120
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 94 |
+
# Π‘Π»ΠΎΠΉ 3 β BusinessVocabulary ΠΈΠ· YAML
|
| 95 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
@check("BusinessVocabulary ΠΈΠ· configs/example_vocabulary.yaml")
|
| 97 |
+
def _():
|
| 98 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 99 |
+
vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
|
| 100 |
+
assert bool(vocab)
|
| 101 |
+
assert "Π²ΡΡΡΡΠΊΠ°" in vocab.terms
|
| 102 |
+
assert "SUM(orders.amount)" in vocab.render_system_context()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 106 |
+
# Π‘Π»ΠΎΠΉ 4 β PromptBuilder Ρ ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ ΡΡ
Π΅ΠΌΠΎΠΉ ΠΈ ΡΠ»ΠΎΠ²Π°ΡΡΠΌ
|
| 107 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 108 |
+
@check("PromptBuilder: vocabulary ΡΡ
ΠΎΠ΄ΠΈΡ Π² system, Π½Π΅ Π² user")
|
| 109 |
+
def _():
|
| 110 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 111 |
+
from src.data.prompt import build_chat_messages
|
| 112 |
+
from src.db.connector import DbConnector
|
| 113 |
+
vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
|
| 114 |
+
schema = DbConnector(str(ROOT / "data" / "demo" / "sales.sqlite")).render_schema()
|
| 115 |
+
msgs = build_chat_messages(schema, "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?", vocabulary=vocab)
|
| 116 |
+
assert "SUM(orders.amount)" in msgs[0]["content"]
|
| 117 |
+
assert "SUM(orders.amount)" not in msgs[1]["content"]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
# Π‘Π»ΠΎΠΉ 5 β ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ° ΠΈ guardrail
|
| 122 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 123 |
+
@check("ΠΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ°: markdown, ΠΏΡΠ΅ΡΠΈΠΊΡΡ, truncated, ΠΌΡΡΠΎΡ")
|
| 124 |
+
def _():
|
| 125 |
+
from src.models.postprocess import postprocess, is_select_only
|
| 126 |
+
cases = [
|
| 127 |
+
("```sql\nSELECT 1;\n```", lambda s: s.upper().startswith("SELECT")),
|
| 128 |
+
("ΠΡΠ²Π΅Ρ: SELECT name FROM customers;", lambda s: s.startswith("SELECT name")),
|
| 129 |
+
("SELECT * FROM orders WHERE", lambda s: s == ""),
|
| 130 |
+
("ΠΏΡΠΎΡΡΠΎ ΡΠ΅ΠΊΡΡ Π±Π΅Π· SQL", lambda s: s == ""),
|
| 131 |
+
]
|
| 132 |
+
for raw, ok in cases:
|
| 133 |
+
assert ok(postprocess(raw)), f"Π‘Π±ΠΎΠΉ Π½Π°: {raw!r} β {postprocess(raw)!r}"
|
| 134 |
+
assert is_select_only("SELECT 1") is True
|
| 135 |
+
assert is_select_only("DROP TABLE x") is False
|
| 136 |
+
assert is_select_only("DELETE FROM x") is False
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
# Π‘Π»ΠΎΠΉ 6 β FastAPI ΡΠ΅ΡΠ΅Π· TestClient Ρ Π·Π°ΠΌΠΎΠΊΠ°Π½Π½ΡΠΌ InferenceEngine
|
| 141 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
@check("FastAPI: /health, /schema, /query Ρ ΠΏΠΎΠ΄ΠΌΠ΅Π½ΡΠ½Π½ΡΠΌ engine")
|
| 143 |
+
def _():
|
| 144 |
+
try:
|
| 145 |
+
from fastapi.testclient import TestClient
|
| 146 |
+
except ImportError:
|
| 147 |
+
raise AssertionError("ΠΠΎΡΡΠ°Π²Ρ fastapi[all] ΠΈΠ»ΠΈ httpx β TestClient Π½Π΅Π΄ΠΎΡΡΡΠΏΠ΅Π½")
|
| 148 |
+
|
| 149 |
+
from src.api import dependencies as deps
|
| 150 |
+
from src.api.main import app
|
| 151 |
+
from src.models.inference import GenerationResult
|
| 152 |
+
|
| 153 |
+
class FakeEngine:
|
| 154 |
+
loaded = True
|
| 155 |
+
base_model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 156 |
+
def generate(self, schema, question, vocabulary=None, **kw):
|
| 157 |
+
# ΠΠΌΡΠ»ΠΈΡΡΠ΅ΠΌ Π²Π°Π»ΠΈΠ΄Π½ΡΠΉ SQL β pipeline Π΄ΠΎΠ»ΠΆΠ΅Π½ ΠΏΡΠΎΠΏΡΡΡΠΈΡΡ ΡΠ΅ΡΠ΅Π· guardrail.
|
| 158 |
+
sql = "SELECT SUM(amount) FROM orders WHERE status = 'paid'"
|
| 159 |
+
return GenerationResult(sql=sql, raw_output=sql)
|
| 160 |
+
|
| 161 |
+
app.dependency_overrides[deps.get_engine] = lambda: FakeEngine()
|
| 162 |
+
try:
|
| 163 |
+
with TestClient(app) as client:
|
| 164 |
+
# /health
|
| 165 |
+
r = client.get("/health")
|
| 166 |
+
assert r.status_code == 200
|
| 167 |
+
assert r.json()["model_loaded"] is True
|
| 168 |
+
|
| 169 |
+
# /schema Π½Π° ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ demo-Π±Π°Π·Π΅
|
| 170 |
+
db_path = ROOT / "data" / "demo" / "sales.sqlite"
|
| 171 |
+
r = client.post("/schema", json={
|
| 172 |
+
"connection_string": str(db_path),
|
| 173 |
+
"include_samples": True,
|
| 174 |
+
})
|
| 175 |
+
assert r.status_code == 200, r.text
|
| 176 |
+
tables = r.json()["tables"]
|
| 177 |
+
assert {t["name"] for t in tables} == {
|
| 178 |
+
"customers", "managers", "products", "orders", "order_items"
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# /query Π½Π° ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ demo-Π±Π°Π·Π΅, FakeEngine ΠΎΡΠ΄Π°ΡΡ Π²Π°Π»ΠΈΠ΄Π½ΡΠΉ SELECT
|
| 182 |
+
r = client.post("/query", json={
|
| 183 |
+
"question": "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° ΠΎΠΏΠ»Π°ΡΠ΅Π½Π½ΡΠ΅ Π·Π°ΠΊΠ°Π·Ρ?",
|
| 184 |
+
"connection_string": str(db_path),
|
| 185 |
+
"execute": True,
|
| 186 |
+
"vocabulary": {
|
| 187 |
+
"company": "ΠΠ΅ΠΌΠΎ",
|
| 188 |
+
"terms": {"Π²ΡΡΡΡΠΊΠ°": "SUM(amount) WHERE status='paid'"},
|
| 189 |
+
},
|
| 190 |
+
})
|
| 191 |
+
assert r.status_code == 200, r.text
|
| 192 |
+
body = r.json()
|
| 193 |
+
assert body["is_valid_sql"] is True
|
| 194 |
+
assert body["execution"] is not None
|
| 195 |
+
assert body["execution"]["rows"][0][0] > 0
|
| 196 |
+
|
| 197 |
+
# /query Ρ DELETE β Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±ΡΡΡ ΠΎΡΠ±ΠΈΡ guardrail'ΠΎΠΌ
|
| 198 |
+
class DropEngine(FakeEngine):
|
| 199 |
+
def generate(self, *a, **kw):
|
| 200 |
+
return GenerationResult(
|
| 201 |
+
sql="DELETE FROM orders WHERE id=1",
|
| 202 |
+
raw_output="DELETE FROM orders WHERE id=1",
|
| 203 |
+
)
|
| 204 |
+
app.dependency_overrides[deps.get_engine] = lambda: DropEngine()
|
| 205 |
+
r = client.post("/query", json={
|
| 206 |
+
"question": "Π£Π΄Π°Π»ΠΈ Π·Π°ΠΊΠ°Π· 1",
|
| 207 |
+
"connection_string": str(db_path),
|
| 208 |
+
"execute": True,
|
| 209 |
+
})
|
| 210 |
+
assert r.status_code == 200, r.text
|
| 211 |
+
body = r.json()
|
| 212 |
+
assert body["execution"] is None
|
| 213 |
+
assert body["error"] and "Π³Π²Π°ΡΠ΄Π΅ΠΉΠ»" in body["error"].lower()
|
| 214 |
+
finally:
|
| 215 |
+
app.dependency_overrides.clear()
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 219 |
+
# Π‘Π»ΠΎΠΉ 7 β ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΡΠΉ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 220 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 221 |
+
def run_model_smoke():
|
| 222 |
+
@check("InferenceEngine: Π·Π°Π³ΡΡΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ ΠΎΠ΄Π½Π° Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ")
|
| 223 |
+
def _():
|
| 224 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 225 |
+
from src.db.connector import DbConnector
|
| 226 |
+
from src.models.inference import InferenceEngine
|
| 227 |
+
|
| 228 |
+
engine = InferenceEngine()
|
| 229 |
+
engine.load()
|
| 230 |
+
assert engine.loaded
|
| 231 |
+
schema = DbConnector(str(ROOT / "data" / "demo" / "sales.sqlite")).render_schema()
|
| 232 |
+
vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
|
| 233 |
+
res = engine.generate(schema, "ΠΠ°ΠΊΠ°Ρ ΡΡΠΌΠΌΠ°ΡΠ½Π°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?", vocab)
|
| 234 |
+
assert res.sql and res.sql.upper().startswith("SELECT"), \
|
| 235 |
+
f"ΠΠΎΠ΄Π΅Π»Ρ Π²Π΅ΡΠ½ΡΠ»Π°: {res.raw_output!r}"
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 239 |
+
# main
|
| 240 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 241 |
+
def main():
|
| 242 |
+
parser = argparse.ArgumentParser()
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--with-model", action="store_true",
|
| 245 |
+
help="ΠΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΠΎ ΠΏΡΠΎΠ³Π½Π°ΡΡ ΡΠ΅Π°Π»ΡΠ½ΡΠΉ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ Qwen (ΠΌΠ΅Π΄Π»Π΅Π½Π½ΠΎ, ~30 ΡΠ΅ΠΊ Π½Π° CPU)",
|
| 246 |
+
)
|
| 247 |
+
args = parser.parse_args()
|
| 248 |
+
|
| 249 |
+
if args.with_model:
|
| 250 |
+
run_model_smoke()
|
| 251 |
+
|
| 252 |
+
print()
|
| 253 |
+
print("=" * 64)
|
| 254 |
+
print("Smoke-ΠΏΡΠΎΠ²Π΅ΡΠΊΠ° Ru2SQL")
|
| 255 |
+
print("=" * 64)
|
| 256 |
+
for status, name, info in results:
|
| 257 |
+
color = GREEN if status == "OK" else RED
|
| 258 |
+
mark = "β" if status == "OK" else "β"
|
| 259 |
+
print(f" {color}{mark}{RESET} {name} {YELLOW}[{info}]{RESET}")
|
| 260 |
+
ok = sum(1 for s, _, _ in results if s == "OK")
|
| 261 |
+
print("=" * 64)
|
| 262 |
+
summary_color = GREEN if ok == len(results) else RED
|
| 263 |
+
print(f" {summary_color}{ok} / {len(results)} ΠΏΡΠΎΠ²Π΅ΡΠΎΠΊ ΠΏΡΠΎΠΉΠ΄Π΅Π½ΠΎ{RESET}")
|
| 264 |
+
print("=" * 64)
|
| 265 |
+
if ok < len(results):
|
| 266 |
+
print()
|
| 267 |
+
print("ΠΠΎΠ΄ΡΠΊΠ°Π·ΠΊΠ°: Π·Π°ΠΏΡΡΡΠΈ 'pytest -v' Π΄Π»Ρ ΠΏΠΎΠ΄ΡΠΎΠ±Π½ΡΡ
Π΄ΠΈΠ°Π³Π½ΠΎΡΡΠΈΠΊ.")
|
| 268 |
+
sys.exit(0 if ok == len(results) else 1)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
main()
|
src/api/main.py
CHANGED
|
@@ -3,11 +3,21 @@
|
|
| 3 |
ΠΠ°ΠΏΡΡΠΊ:
|
| 4 |
uvicorn src.api.main:app --reload
|
| 5 |
# Swagger UI: http://127.0.0.1:8000/docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
|
|
|
| 10 |
import sqlite3
|
|
|
|
| 11 |
|
| 12 |
from fastapi import Depends, FastAPI, HTTPException
|
| 13 |
from fastapi.concurrency import run_in_threadpool
|
|
@@ -19,29 +29,59 @@ from src.api.schemas import (
|
|
| 19 |
GenerateRequest,
|
| 20 |
GenerateResponse,
|
| 21 |
HealthResponse,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
)
|
|
|
|
| 23 |
from src.config import settings
|
| 24 |
from src.data.schema import SchemaRetriever
|
|
|
|
|
|
|
| 25 |
from src.models.inference import InferenceEngine
|
| 26 |
-
from src.models.postprocess import is_valid_sql
|
|
|
|
|
|
|
| 27 |
|
| 28 |
app = FastAPI(
|
| 29 |
title="ru2sql",
|
| 30 |
description="ΠΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ Π²ΠΎΠΏΡΠΎΡΠΎΠ² Π½Π° ΡΡΡΡΠΊΠΎΠΌ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ",
|
| 31 |
-
version="0.
|
| 32 |
lifespan=lifespan,
|
| 33 |
)
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
@app.get("/health", response_model=HealthResponse)
|
| 37 |
def health(engine: InferenceEngine = Depends(get_engine)):
|
| 38 |
return HealthResponse(
|
| 39 |
status="ok",
|
| 40 |
-
model_loaded=engine.
|
| 41 |
base_model=engine.base_model_name,
|
| 42 |
)
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
@app.get("/databases", response_model=list[DatabaseInfo])
|
| 46 |
def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
|
| 47 |
out: list[DatabaseInfo] = []
|
|
@@ -54,21 +94,33 @@ def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
|
|
| 54 |
return out
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
@app.post("/generate-sql", response_model=GenerateResponse)
|
| 58 |
async def generate_sql(
|
| 59 |
req: GenerateRequest,
|
| 60 |
engine: InferenceEngine = Depends(get_engine),
|
| 61 |
retriever: SchemaRetriever = Depends(get_schema_retriever),
|
| 62 |
):
|
|
|
|
| 63 |
try:
|
| 64 |
schema_text = retriever.render_schema(req.db_id)
|
| 65 |
except FileNotFoundError as e:
|
| 66 |
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
|
|
|
|
|
|
|
|
|
| 71 |
valid = is_valid_sql(result.sql)
|
|
|
|
| 72 |
response = GenerateResponse(
|
| 73 |
sql=result.sql,
|
| 74 |
raw_output=result.raw_output,
|
|
@@ -76,9 +128,15 @@ async def generate_sql(
|
|
| 76 |
)
|
| 77 |
|
| 78 |
if req.execute and valid:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
try:
|
| 80 |
response.execution = await run_in_threadpool(
|
| 81 |
-
|
| 82 |
)
|
| 83 |
except sqlite3.Error as e:
|
| 84 |
response.error = f"SQL execution error: {e}"
|
|
@@ -86,7 +144,7 @@ async def generate_sql(
|
|
| 86 |
return response
|
| 87 |
|
| 88 |
|
| 89 |
-
def
|
| 90 |
db_path = retriever.db_path(db_id)
|
| 91 |
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
| 92 |
try:
|
|
@@ -95,16 +153,102 @@ def _execute_sql(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionR
|
|
| 95 |
cur.execute(sql)
|
| 96 |
rows = cur.fetchmany(100)
|
| 97 |
cols = [d[0] for d in cur.description] if cur.description else []
|
| 98 |
-
return ExecutionResult(
|
| 99 |
-
columns=cols,
|
| 100 |
-
rows=[list(r) for r in rows],
|
| 101 |
-
row_count=len(rows),
|
| 102 |
-
)
|
| 103 |
finally:
|
| 104 |
conn.close()
|
| 105 |
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if __name__ == "__main__":
|
| 108 |
import uvicorn
|
| 109 |
-
|
| 110 |
uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
|
|
|
|
| 3 |
ΠΠ°ΠΏΡΡΠΊ:
|
| 4 |
uvicorn src.api.main:app --reload
|
| 5 |
# Swagger UI: http://127.0.0.1:8000/docs
|
| 6 |
+
|
| 7 |
+
ΠΠ½Π΄ΠΏΠΎΠΈΠ½ΡΡ:
|
| 8 |
+
GET /health β ΡΡΠ°ΡΡΡ ΡΠ΅ΡΠ²ΠΈΡΠ° ΠΈ Π·Π°Π³ΡΡΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
| 9 |
+
GET /databases β ΡΠΏΠΈΡΠΎΠΊ ΠΠ ΠΈΠ· data/databases (PAUQ-ΡΡΡΡΠΊΡΡΡΠ°)
|
| 10 |
+
POST /generate-sql β Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ SQL ΠΏΠΎ db_id ΠΈΠ· PAUQ
|
| 11 |
+
POST /schema β ΡΡ
Π΅ΠΌΠ° ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ ΠΏΠΎ connection string
|
| 12 |
+
POST /query β ΠΏΠΎΠ»Π½ΡΠΉ pipeline Π΄Π»Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ
|
| 13 |
+
(Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ + ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎΠ΅ ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ + guardrail)
|
| 14 |
"""
|
| 15 |
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
+
import logging
|
| 19 |
import sqlite3
|
| 20 |
+
import time
|
| 21 |
|
| 22 |
from fastapi import Depends, FastAPI, HTTPException
|
| 23 |
from fastapi.concurrency import run_in_threadpool
|
|
|
|
| 29 |
GenerateRequest,
|
| 30 |
GenerateResponse,
|
| 31 |
HealthResponse,
|
| 32 |
+
QueryRequest,
|
| 33 |
+
QueryResponse,
|
| 34 |
+
SchemaRequest,
|
| 35 |
+
SchemaResponse,
|
| 36 |
+
TablePayload,
|
| 37 |
+
ColumnPayload,
|
| 38 |
)
|
| 39 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 40 |
from src.config import settings
|
| 41 |
from src.data.schema import SchemaRetriever
|
| 42 |
+
from src.data.schema_provider import ConnectionSchemaProvider
|
| 43 |
+
from src.db.executor import SqlExecutor
|
| 44 |
from src.models.inference import InferenceEngine
|
| 45 |
+
from src.models.postprocess import is_select_only, is_valid_sql
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
|
| 49 |
app = FastAPI(
|
| 50 |
title="ru2sql",
|
| 51 |
description="ΠΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ Π²ΠΎΠΏΡΠΎΡΠΎΠ² Π½Π° ΡΡΡΡΠΊΠΎΠΌ Π² SQL-Π·Π°ΠΏΡΠΎΡΡ",
|
| 52 |
+
version="0.2.0",
|
| 53 |
lifespan=lifespan,
|
| 54 |
)
|
| 55 |
|
| 56 |
|
| 57 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
# ΠΠ°Π·ΠΎΠ²ΡΠ΅ ΡΠ½Π΄ΠΏΠΎΠΈΠ½ΡΡ
|
| 59 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
+
|
| 61 |
@app.get("/health", response_model=HealthResponse)
|
| 62 |
def health(engine: InferenceEngine = Depends(get_engine)):
|
| 63 |
return HealthResponse(
|
| 64 |
status="ok",
|
| 65 |
+
model_loaded=engine.loaded,
|
| 66 |
base_model=engine.base_model_name,
|
| 67 |
)
|
| 68 |
|
| 69 |
|
| 70 |
+
@app.post("/warmup")
|
| 71 |
+
async def warmup(engine: InferenceEngine = Depends(get_engine)):
|
| 72 |
+
"""ΠΡΠΎΠ³ΡΠ΅Π²Π°Π΅Ρ ΠΌΠΎΠ΄Π΅Π»Ρ ΠΎΠ΄Π½ΠΎΠΉ ΠΊΠΎΡΠΎΡΠΊΠΎΠΉ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠ΅ΠΉ.
|
| 73 |
+
|
| 74 |
+
ΠΠ΅ΡΠ²ΡΠΉ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ Ρ
ΠΎΠ»ΠΎΠ΄Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ Π½Π° CPU ΡΠΈΠ»ΡΠ½ΠΎ Π΄ΠΎΠ»ΡΡΠ΅ ΠΏΠΎΡΠ»Π΅Π΄ΡΡΡΠΈΡ
:
|
| 75 |
+
ΠΏΠΎΠ΄Π³ΡΡΠΆΠ°ΡΡΡΡ LoRA-ΡΠ»ΠΎΠΈ, ΡΠΎΡΠΌΠΈΡΡΠ΅ΡΡΡ Π³ΡΠ°Ρ Π²ΡΡΠΈΡΠ»Π΅Π½ΠΈΠΉ, Π·Π°ΠΏΠΎΠ»Π½ΡΠ΅ΡΡΡ
|
| 76 |
+
KV-ΠΊΠ΅Ρ. ΠΡΠ·ΠΎΠ² /warmup Π΄Π΅Π»Π°Π΅Ρ ΠΎΠ΄ΠΈΠ½ ΠΌΠ°Π»Π΅Π½ΡΠΊΠΈΠΉ ΠΏΡΠΎΡ
ΠΎΠ΄ Ρ ΠΌΠΈΠ½ΠΈΠΌΠ°Π»ΡΠ½ΡΠΌ
|
| 77 |
+
max_new_tokens, ΡΡΠΎΠ±Ρ Π±ΠΎΠ΅Π²ΠΎΠΉ /query ΡΡΠ» ΡΠΆΠ΅ ΠΏΠΎ ΠΏΡΠΎΠ³ΡΠ΅ΡΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
|
| 78 |
+
"""
|
| 79 |
+
t0 = time.time()
|
| 80 |
+
schema = "CREATE TABLE t (id INT);"
|
| 81 |
+
await run_in_threadpool(engine.generate, schema, "SELECT id", None, 16)
|
| 82 |
+
return {"warmup_seconds": round(time.time() - t0, 2)}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
@app.get("/databases", response_model=list[DatabaseInfo])
|
| 86 |
def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
|
| 87 |
out: list[DatabaseInfo] = []
|
|
|
|
| 94 |
return out
|
| 95 |
|
| 96 |
|
| 97 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 98 |
+
# PAUQ-ΡΡΠ΅Π½Π°ΡΠΈΠΉ (ΡΡΠ°ΡΡΠΉ ΡΠ½Π΄ΠΏΠΎΠΈΠ½Ρ, ΠΎΡΡΠ°Π²Π»Π΅Π½ Π΄Π»Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΠΈ)
|
| 99 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
|
| 101 |
@app.post("/generate-sql", response_model=GenerateResponse)
|
| 102 |
async def generate_sql(
|
| 103 |
req: GenerateRequest,
|
| 104 |
engine: InferenceEngine = Depends(get_engine),
|
| 105 |
retriever: SchemaRetriever = Depends(get_schema_retriever),
|
| 106 |
):
|
| 107 |
+
"""ΠΠ΅Π½Π΅ΡΠ°ΡΠΈΡ SQL Π΄Π»Ρ Π±Π°Π·Ρ ΠΈΠ· PAUQ-ΡΡΡΡΠΊΡΡΡΡ (data/databases/{db_id})."""
|
| 108 |
try:
|
| 109 |
schema_text = retriever.render_schema(req.db_id)
|
| 110 |
except FileNotFoundError as e:
|
| 111 |
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 112 |
|
| 113 |
+
vocab = (
|
| 114 |
+
BusinessVocabulary.from_dict(req.vocabulary.model_dump())
|
| 115 |
+
if req.vocabulary
|
| 116 |
+
else None
|
| 117 |
+
)
|
| 118 |
|
| 119 |
+
result = await run_in_threadpool(
|
| 120 |
+
engine.generate, schema_text, req.question, vocab
|
| 121 |
+
)
|
| 122 |
valid = is_valid_sql(result.sql)
|
| 123 |
+
|
| 124 |
response = GenerateResponse(
|
| 125 |
sql=result.sql,
|
| 126 |
raw_output=result.raw_output,
|
|
|
|
| 128 |
)
|
| 129 |
|
| 130 |
if req.execute and valid:
|
| 131 |
+
if not is_select_only(result.sql):
|
| 132 |
+
response.error = (
|
| 133 |
+
"SQL ΠΎΡΠΊΠ»ΠΎΠ½ΡΠ½ Π³Π²Π°ΡΠ΄Π΅ΠΉΠ»ΠΎΠΌ: ΡΠ°Π·ΡΠ΅ΡΠ΅Π½Ρ ΡΠΎΠ»ΡΠΊΠΎ Π·Π°ΠΏΡΠΎΡΡ SELECT ΠΈ WITH."
|
| 134 |
+
)
|
| 135 |
+
logger.warning("Guardrail ΠΎΡΠΊΠ»ΠΎΠ½ΠΈΠ» SQL: %r", result.sql[:120])
|
| 136 |
+
return response
|
| 137 |
try:
|
| 138 |
response.execution = await run_in_threadpool(
|
| 139 |
+
_execute_sql_pauq, req.db_id, result.sql, retriever
|
| 140 |
)
|
| 141 |
except sqlite3.Error as e:
|
| 142 |
response.error = f"SQL execution error: {e}"
|
|
|
|
| 144 |
return response
|
| 145 |
|
| 146 |
|
| 147 |
+
def _execute_sql_pauq(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionResult:
|
| 148 |
db_path = retriever.db_path(db_id)
|
| 149 |
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
| 150 |
try:
|
|
|
|
| 153 |
cur.execute(sql)
|
| 154 |
rows = cur.fetchmany(100)
|
| 155 |
cols = [d[0] for d in cur.description] if cur.description else []
|
| 156 |
+
return ExecutionResult(columns=cols, rows=[list(r) for r in rows], row_count=len(rows))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
finally:
|
| 158 |
conn.close()
|
| 159 |
|
| 160 |
|
| 161 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 162 |
+
# ΠΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½Π°Ρ ΠΠ ΠΏΠΎ connection string β Π½ΠΎΠ²ΡΠΉ ΡΡΠ΅Π½Π°ΡΠΈΠΉ Π΄Π»Ρ Streamlit
|
| 163 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 164 |
+
|
| 165 |
+
@app.post("/schema", response_model=SchemaResponse)
|
| 166 |
+
async def get_schema(req: SchemaRequest):
|
| 167 |
+
"""ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ ΡΡ
Π΅ΠΌΡ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ Π΄Π»Ρ ΠΎΡΠΎΠ±ΡΠ°ΠΆΠ΅Π½ΠΈΡ Π² ΠΊΠ»ΠΈΠ΅Π½ΡΠ΅."""
|
| 168 |
+
try:
|
| 169 |
+
provider = ConnectionSchemaProvider(req.connection_string)
|
| 170 |
+
tables = await run_in_threadpool(provider.get_tables, 2 if req.include_samples else 0)
|
| 171 |
+
except Exception as e: # noqa: BLE001
|
| 172 |
+
raise HTTPException(status_code=400, detail=f"ΠΡΠΈΠ±ΠΊΠ° ΡΡΠ΅Π½ΠΈΡ ΡΡ
Π΅ΠΌΡ: {e}") from e
|
| 173 |
+
|
| 174 |
+
payload = [
|
| 175 |
+
TablePayload(
|
| 176 |
+
name=t.name,
|
| 177 |
+
columns=[
|
| 178 |
+
ColumnPayload(
|
| 179 |
+
name=c.name, type=c.type,
|
| 180 |
+
nullable=c.nullable, primary_key=c.primary_key,
|
| 181 |
+
)
|
| 182 |
+
for c in t.columns
|
| 183 |
+
],
|
| 184 |
+
sample_rows=[list(r) for r in t.sample_rows],
|
| 185 |
+
ddl=t.to_ddl(),
|
| 186 |
+
)
|
| 187 |
+
for t in tables
|
| 188 |
+
]
|
| 189 |
+
return SchemaResponse(tables=payload)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@app.post("/query", response_model=QueryResponse)
|
| 193 |
+
async def query(
|
| 194 |
+
req: QueryRequest,
|
| 195 |
+
engine: InferenceEngine = Depends(get_engine),
|
| 196 |
+
):
|
| 197 |
+
"""ΠΠΎΠ»Π½ΡΠΉ pipeline: Π²ΠΎΠΏΡΠΎΡ β SQL β ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎΠ΅ ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ Π½Π° ΠΠ.
|
| 198 |
+
|
| 199 |
+
Π ΠΎΡΠ»ΠΈΡΠΈΠ΅ ΠΎΡ /generate-sql, ΡΠ°Π±ΠΎΡΠ°Π΅Ρ Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ ΠΏΠΎ connection
|
| 200 |
+
string. ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ Streamlit-ΠΊΠ»ΠΈΠ΅Π½ΡΠΎΠΌ ΠΈ ΡΡΠΎΡΠΎΠ½Π½ΠΈΠΌΠΈ ΠΈΠ½ΡΠ΅Π³ΡΠ°ΡΠΈΡΠΌΠΈ.
|
| 201 |
+
ΠΠ΅ΡΠ΅Π΄ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ SQL ΠΏΡΠΎΡ
ΠΎΠ΄ΠΈΡ ΠΏΡΠΎΠ²Π΅ΡΠΊΡ is_select_only (ΡΠ°Π·Π΄Π΅Π» 4.4).
|
| 202 |
+
"""
|
| 203 |
+
# 1. Π‘Ρ
Π΅ΠΌΠ° ΡΠ΅Π»Π΅Π²ΠΎΠΉ ΠΠ
|
| 204 |
+
try:
|
| 205 |
+
provider = ConnectionSchemaProvider(req.connection_string)
|
| 206 |
+
schema_text = await run_in_threadpool(provider.render_schema, True)
|
| 207 |
+
except Exception as e: # noqa: BLE001
|
| 208 |
+
raise HTTPException(status_code=400, detail=f"ΠΡΠΈΠ±ΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ ΠΊ ΠΠ: {e}") from e
|
| 209 |
+
|
| 210 |
+
vocab = (
|
| 211 |
+
BusinessVocabulary.from_dict(req.vocabulary.model_dump())
|
| 212 |
+
if req.vocabulary
|
| 213 |
+
else None
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# 2. ΠΠ½ΡΠ΅ΡΠ΅Π½Ρ
|
| 217 |
+
t0 = time.time()
|
| 218 |
+
result = await run_in_threadpool(engine.generate, schema_text, req.question, vocab)
|
| 219 |
+
gen_time = time.time() - t0
|
| 220 |
+
|
| 221 |
+
valid = is_valid_sql(result.sql)
|
| 222 |
+
response = QueryResponse(
|
| 223 |
+
sql=result.sql,
|
| 224 |
+
raw_output=result.raw_output,
|
| 225 |
+
is_valid_sql=valid,
|
| 226 |
+
gen_time_seconds=round(gen_time, 2),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# 3. ΠΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎΠ΅ ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅
|
| 230 |
+
if req.execute and valid:
|
| 231 |
+
if not is_select_only(result.sql):
|
| 232 |
+
response.error = (
|
| 233 |
+
"SQL ΠΎΡΠΊΠ»ΠΎΠ½ΡΠ½ Π³Π²Π°ΡΠ΄Π΅ΠΉΠ»ΠΎΠΌ: ΡΠ°Π·ΡΠ΅ΡΠ΅Π½Ρ οΏ½οΏ½ΠΎΠ»ΡΠΊΠΎ Π·Π°ΠΏΡΠΎΡΡ SELECT ΠΈ WITH."
|
| 234 |
+
)
|
| 235 |
+
logger.warning("Guardrail ΠΎΡΠΊΠ»ΠΎΠ½ΠΈΠ» SQL: %r", result.sql[:120])
|
| 236 |
+
return response
|
| 237 |
+
try:
|
| 238 |
+
executor = SqlExecutor(req.connection_string)
|
| 239 |
+
qr = await run_in_threadpool(executor.run, result.sql)
|
| 240 |
+
if qr.success:
|
| 241 |
+
response.execution = ExecutionResult(
|
| 242 |
+
columns=qr.columns, rows=qr.rows, row_count=qr.row_count,
|
| 243 |
+
)
|
| 244 |
+
else:
|
| 245 |
+
response.error = f"SQL execution error: {qr.error}"
|
| 246 |
+
except Exception as e: # noqa: BLE001
|
| 247 |
+
response.error = f"SQL execution error: {e}"
|
| 248 |
+
|
| 249 |
+
return response
|
| 250 |
+
|
| 251 |
+
|
| 252 |
if __name__ == "__main__":
|
| 253 |
import uvicorn
|
|
|
|
| 254 |
uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
|
src/api/schemas.py
CHANGED
|
@@ -5,10 +5,30 @@ from __future__ import annotations
|
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
class GenerateRequest(BaseModel):
|
| 9 |
question: str = Field(..., min_length=1, max_length=2000, description="ΠΠΎΠΏΡΠΎΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ")
|
| 10 |
db_id: str = Field(..., min_length=1, description="ΠΠ΄Π΅Π½ΡΠΈΡΠΈΠΊΠ°ΡΠΎΡ ΠΠ ΠΈΠ· PAUQ")
|
| 11 |
execute: bool = Field(default=False, description="ΠΡΠΎΠ³Π½Π°ΡΡ ΡΠ³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΡΠΉ SQL Π½Π° ΠΠ")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class ExecutionResult(BaseModel):
|
|
@@ -25,6 +45,66 @@ class GenerateResponse(BaseModel):
|
|
| 25 |
error: str | None = None
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
class DatabaseInfo(BaseModel):
|
| 29 |
db_id: str
|
| 30 |
tables: list[str]
|
|
|
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
| 7 |
|
| 8 |
+
class VocabularyPayload(BaseModel):
|
| 9 |
+
"""ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ Π² ΡΠΎΡΠΌΠ°ΡΠ΅, ΡΠΎΠ³Π»Π°ΡΠΎΠ²Π°Π½Π½ΠΎΠΌ Ρ BusinessVocabulary.from_dict.
|
| 10 |
+
|
| 11 |
+
ΠΠ΅ΡΠ΅Π΄Π°ΡΡΡΡ ΠΊΠ»ΠΈΠ΅Π½ΡΠΎΠΌ ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎ Π² Π·Π°ΠΏΡΠΎΡΠ΅ Π½Π° Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ SQL. ΠΡΠ»ΠΈ ΠΏΠΎΠ»Π΅
|
| 12 |
+
ΠΎΡΡΡΡΡΡΠ²ΡΠ΅Ρ β ΠΌΠΎΠ΄Π΅Π»Ρ ΡΠ°Π±ΠΎΡΠ°Π΅Ρ Π±Π΅Π· ΡΠΏΠ΅ΡΠΈΠ°Π»ΡΠ½ΡΡ
Π±ΠΈΠ·Π½Π΅Ρ-ΠΎΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΠΉ.
|
| 13 |
+
"""
|
| 14 |
+
company: str = ""
|
| 15 |
+
terms: dict[str, str] = Field(default_factory=dict)
|
| 16 |
+
filters: dict[str, str] = Field(default_factory=dict)
|
| 17 |
+
notes: list[str] = Field(default_factory=list)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
# /generate-sql β ΡΡΠ°ΡΡΠΉ ΡΠ½Π΄ΠΏΠΎΠΈΠ½Ρ Π΄Π»Ρ PAUQ-ΡΡΡΡΠΊΡΡΡΡ (databases_dir + db_id)
|
| 22 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
|
| 24 |
class GenerateRequest(BaseModel):
|
| 25 |
question: str = Field(..., min_length=1, max_length=2000, description="ΠΠΎΠΏΡΠΎΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ")
|
| 26 |
db_id: str = Field(..., min_length=1, description="ΠΠ΄Π΅Π½ΡΠΈΡΠΈΠΊΠ°ΡΠΎΡ ΠΠ ΠΈΠ· PAUQ")
|
| 27 |
execute: bool = Field(default=False, description="ΠΡΠΎΠ³Π½Π°ΡΡ ΡΠ³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΡΠΉ SQL Π½Π° ΠΠ")
|
| 28 |
+
vocabulary: VocabularyPayload | None = Field(
|
| 29 |
+
default=None,
|
| 30 |
+
description="ΠΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΡΠΉ Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ (ΡΠΌ. ΡΠ°Π·Π΄Π΅Π» 3.6 ΠΠΠ )",
|
| 31 |
+
)
|
| 32 |
|
| 33 |
|
| 34 |
class ExecutionResult(BaseModel):
|
|
|
|
| 45 |
error: str | None = None
|
| 46 |
|
| 47 |
|
| 48 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
# /query β Π½ΠΎΠ²ΡΠΉ ΡΠ½Π΄ΠΏΠΎΠΈΠ½Ρ Π΄Π»Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ (connection string)
|
| 50 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
|
| 52 |
+
class QueryRequest(BaseModel):
|
| 53 |
+
"""ΠΠΎΠ»Π½ΡΠΉ Π·Π°ΠΏΡΠΎΡ Β«Π²ΠΎΠΏΡΠΎΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ β SQL β ΡΠ΅Π·ΡΠ»ΡΡΠ°ΡΒ» Π΄Π»Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ.
|
| 54 |
+
|
| 55 |
+
Π ΠΎΡΠ»ΠΈΡΠΈΠ΅ ΠΎΡ GenerateRequest, Π½Π΅ ΠΏΡΠΈΠ²ΡΠ·Π°Π½ ΠΊ PAUQ-ΡΡΡΡΠΊΡΡΡΠ΅: ΠΊΠ»ΠΈΠ΅Π½Ρ ΡΠ°ΠΌ
|
| 56 |
+
ΠΏΠ΅ΡΠ΅Π΄Π°ΡΡ connection string (SQLite/PostgreSQL/MySQL). ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ
|
| 57 |
+
Streamlit-ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡΠΎΠΌ ΠΈ Π»ΡΠ±ΡΠΌΠΈ ΡΡΠΎΡΠΎΠ½Π½ΠΈΠΌΠΈ ΠΊΠ»ΠΈΠ΅Π½ΡΠ°ΠΌΠΈ.
|
| 58 |
+
"""
|
| 59 |
+
question: str = Field(..., min_length=1, max_length=2000)
|
| 60 |
+
connection_string: str = Field(
|
| 61 |
+
..., min_length=1,
|
| 62 |
+
description="Π‘ΡΡΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ. ΠΡΠΈΠΌΠ΅Ρ: sqlite:///data/demo/sales.sqlite",
|
| 63 |
+
)
|
| 64 |
+
execute: bool = Field(default=True)
|
| 65 |
+
vocabulary: VocabularyPayload | None = None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class QueryResponse(BaseModel):
|
| 69 |
+
sql: str
|
| 70 |
+
raw_output: str
|
| 71 |
+
is_valid_sql: bool
|
| 72 |
+
gen_time_seconds: float
|
| 73 |
+
execution: ExecutionResult | None = None
|
| 74 |
+
error: str | None = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
# /schema β ΠΎΡΠ΄Π°ΡΡ ΡΡ
Π΅ΠΌΡ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ Π΄Π»Ρ ΠΏΠΎΠ΄ΡΡΠ°Π½ΠΎΠ²ΠΊΠΈ Π² UI
|
| 79 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
+
|
| 81 |
+
class SchemaRequest(BaseModel):
|
| 82 |
+
connection_string: str = Field(..., min_length=1)
|
| 83 |
+
include_samples: bool = Field(default=True)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ColumnPayload(BaseModel):
|
| 87 |
+
name: str
|
| 88 |
+
type: str
|
| 89 |
+
nullable: bool
|
| 90 |
+
primary_key: bool
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TablePayload(BaseModel):
|
| 94 |
+
name: str
|
| 95 |
+
columns: list[ColumnPayload]
|
| 96 |
+
sample_rows: list[list]
|
| 97 |
+
ddl: str
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class SchemaResponse(BaseModel):
|
| 101 |
+
tables: list[TablePayload]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
+
# ΠΡΠΎΡΠ΅Π΅
|
| 106 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
|
| 108 |
class DatabaseInfo(BaseModel):
|
| 109 |
db_id: str
|
| 110 |
tables: list[str]
|
src/data/prompt.py
CHANGED
|
@@ -1,8 +1,26 @@
|
|
| 1 |
-
"""PromptBuilder β ΡΠΎΡΠΌΠΈΡ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"Π’Ρ β Π°ΡΡΠΈΡΡΠ΅Π½Ρ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΡΠ΅Ρ Π²ΠΎΠΏΡΠΎΡΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠ΅ SQL-Π·Π°ΠΏΡΠΎΡΡ. "
|
| 7 |
"Π’Π΅Π±Π΅ Π΄Π°ΡΡΡΡ ΡΡ
Π΅ΠΌΠ° Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
Π² Π²ΠΈΠ΄Π΅ CREATE TABLE statements ΠΈ ΠΏΡΠΈΠΌΠ΅Ρ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΈΡ
ΡΡΡΠΎΠΊ. "
|
| 8 |
"Π‘Π³Π΅Π½Π΅ΡΠΈΡΡΠΉ ΠΎΠ΄ΠΈΠ½ SQL-Π·Π°ΠΏΡΠΎΡ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΎΡΠ²Π΅ΡΠ°Π΅Ρ Π½Π° Π²ΠΎΠΏΡΠΎΡ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»Ρ. "
|
|
@@ -11,19 +29,63 @@ SYSTEM_PROMPT = (
|
|
| 11 |
|
| 12 |
|
| 13 |
def build_user_message(schema: str, question: str) -> str:
|
|
|
|
| 14 |
return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
|
| 15 |
|
| 16 |
|
| 17 |
-
def
|
| 18 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
return [
|
| 20 |
-
{"role": "system", "content":
|
| 21 |
{"role": "user", "content": build_user_message(schema, question)},
|
| 22 |
]
|
| 23 |
|
| 24 |
|
| 25 |
-
def build_training_example(
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
msgs.append({"role": "assistant", "content": sql.strip()})
|
| 29 |
return msgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PromptBuilder β ΡΠΎΡΠΌΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ chat-template input Π΄Π»Ρ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
|
| 2 |
+
|
| 3 |
+
Π‘ΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Ρ 2.4 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ. ΠΠ΄ΠΈΠ½ ΠΈ ΡΠΎΡ ΠΆΠ΅ Π±ΠΈΠ»Π΄Π΅Ρ
|
| 4 |
+
ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ ΠΏΡΠΈ ΠΎΠ±ΡΡΠ΅Π½ΠΈΠΈ (ΡΠΎΡΠΌΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ SFT-ΠΏΡΠΈΠΌΠ΅ΡΠΎΠ²) ΠΈ ΠΏΡΠΈ ΠΈΠ½ΡΠ΅ΡΠ΅Π½ΡΠ΅
|
| 5 |
+
(ΡΠΎΡΠΌΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ Π·Π°ΠΏΡΠΎΡΠ° ΠΊ Π·Π°Π³ΡΡΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ), ΡΡΠΎ Π³Π°ΡΠ°Π½ΡΠΈΡΡΠ΅Ρ ΡΠΎΠ²ΠΏΠ°Π΄Π΅Π½ΠΈΠ΅
|
| 6 |
+
ΡΠΎΡΠΌΠ°ΡΠ° train- ΠΈ inference-time ΠΏΡΠΎΠΌΠΏΡΠΎΠ².
|
| 7 |
+
|
| 8 |
+
ΠΠΎΠΌΠΈΠΌΠΎ ΡΡ
Π΅ΠΌΡ ΠΈ Π²ΠΎΠΏΡΠΎΡΠ°, Π±ΠΈΠ»Π΄Π΅Ρ ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎ ΠΏΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ BusinessVocabulary
|
| 9 |
+
(ΡΠ°Π·Π΄Π΅Π» 3.6 ΠΠΠ ). ΠΠΈΠ·Π½Π΅Ρ-ΡΠ΅ΡΠΌΠΈΠ½Ρ ΠΏΠΎΠ΄ΠΌΠ΅ΡΠΈΠ²Π°ΡΡΡΡ Π² ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠ΅ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅,
|
| 10 |
+
Π° Π½Π΅ ΠΊΠΎΠ½ΠΊΠ°ΡΠ΅Π½ΠΈΡΡΡΡΡΡ ΠΊ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»ΡΡΠΊΠΎΠΌΡ Π²ΠΎΠΏΡΠΎΡΡ β ΡΡΠΎ ΡΠΎΠ³Π»Π°ΡΡΠ΅ΡΡΡ Ρ
|
| 11 |
+
ΡΠ΅ΠΌ, ΠΊΠ°ΠΊ ΡΠΎΠ²ΡΠ΅ΠΌΠ΅Π½Π½ΡΠ΅ instruction-tuned ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈΠ½ΡΠ΅ΡΠΏΡΠ΅ΡΠΈΡΡΡΡ ΡΠΎΠ»ΠΈ
|
| 12 |
+
ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠΉ Π² chat-template.
|
| 13 |
+
"""
|
| 14 |
|
| 15 |
from __future__ import annotations
|
| 16 |
|
| 17 |
+
from typing import TYPE_CHECKING
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
BASE_SYSTEM_PROMPT = (
|
| 24 |
"Π’Ρ β Π°ΡΡΠΈΡΡΠ΅Π½Ρ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΡΠ΅Ρ Π²ΠΎΠΏΡΠΎΡΡ Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠ΅ SQL-Π·Π°ΠΏΡΠΎΡΡ. "
|
| 25 |
"Π’Π΅Π±Π΅ Π΄Π°ΡΡΡΡ ΡΡ
Π΅ΠΌΠ° Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
Π² Π²ΠΈΠ΄Π΅ CREATE TABLE statements ΠΈ ΠΏΡΠΈΠΌΠ΅Ρ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΈΡ
ΡΡΡΠΎΠΊ. "
|
| 26 |
"Π‘Π³Π΅Π½Π΅ΡΠΈΡΡΠΉ ΠΎΠ΄ΠΈΠ½ SQL-Π·Π°ΠΏΡΠΎΡ, ΠΊΠΎΡΠΎΡΡΠΉ ΠΎΡΠ²Π΅ΡΠ°Π΅Ρ Π½Π° Π²ΠΎΠΏΡΠΎΡ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»Ρ. "
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def build_user_message(schema: str, question: str) -> str:
|
| 32 |
+
"""ΠΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»ΡΡΠΊΠ°Ρ ΡΠ°ΡΡΡ ΠΏΡΠΎΠΌΠΏΡΠ° Π² ΡΠΎΡΠΌΠ°ΡΠ΅ ``### Schema / ### Question / ### SQL:``."""
|
| 33 |
return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
|
| 34 |
|
| 35 |
|
| 36 |
+
def build_system_message(vocabulary: "BusinessVocabulary | None" = None) -> str:
|
| 37 |
+
"""Π‘ΠΎΠ±ΠΈΡΠ°Π΅Ρ ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠ΅ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅.
|
| 38 |
+
|
| 39 |
+
ΠΡΠ»ΠΈ ΠΏΠ΅ΡΠ΅Π΄Π°Π½ Π½Π΅ΠΏΡΡΡΠΎΠΉ Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ, ΠΊ Π±Π°Π·ΠΎΠ²ΠΎΠΌΡ ΠΏΡΠΎΠΌΠΏΡΡ Π΄ΠΎΠ±Π°Π²Π»ΡΠ΅ΡΡΡ
|
| 40 |
+
Π±Π»ΠΎΠΊ Ρ ΠΎΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΡΠΌΠΈ ΡΠ΅ΡΠΌΠΈΠ½ΠΎΠ², ΡΠΈΠ»ΡΡΡΠ°ΠΌΠΈ ΠΈ ΠΏΡΠ°Π²ΠΈΠ»Π°ΠΌΠΈ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ. ΠΡΠΎ
|
| 41 |
+
ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ Π°Π΄Π°ΠΏΡΠΈΡΠΎΠ²Π°ΡΡ ΡΠΈΡΡΠ΅ΠΌΡ ΠΊ ΡΠ΅ΡΠΌΠΈΠ½ΠΎΠ»ΠΎΠ³ΠΈΠΈ ΠΊΠΎΠ½ΠΊΡΠ΅ΡΠ½ΠΎΠΉ ΠΎΡΠ³Π°Π½ΠΈΠ·Π°ΡΠΈΠΈ
|
| 42 |
+
Π±Π΅Π· ΠΏΠΎΠ²ΡΠΎΡΠ½ΠΎΠ³ΠΎ Π΄ΠΎΠΎΠ±ΡΡΠ΅Π½ΠΈΡ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
|
| 43 |
+
"""
|
| 44 |
+
if vocabulary is None or not vocabulary:
|
| 45 |
+
return BASE_SYSTEM_PROMPT
|
| 46 |
+
context = vocabulary.render_system_context()
|
| 47 |
+
if not context:
|
| 48 |
+
return BASE_SYSTEM_PROMPT
|
| 49 |
+
return BASE_SYSTEM_PROMPT + "\n\n" + context
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def build_chat_messages(
|
| 53 |
+
schema: str,
|
| 54 |
+
question: str,
|
| 55 |
+
vocabulary: "BusinessVocabulary | None" = None,
|
| 56 |
+
) -> list[dict]:
|
| 57 |
+
"""Π‘ΠΎΠΎΠ±ΡΠ΅Π½ΠΈΡ Π΄Π»Ρ ``tokenizer.apply_chat_template``.
|
| 58 |
+
|
| 59 |
+
ΠΠ°ΡΠ°ΠΌΠ΅ΡΡΡ
|
| 60 |
+
---------
|
| 61 |
+
schema : str
|
| 62 |
+
Π’Π΅ΠΊΡΡΠΎΠ²ΠΎΠ΅ ΠΏΡΠ΅Π΄ΡΡΠ°Π²Π»Π΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌΡ (CREATE TABLE + sample rows).
|
| 63 |
+
question : str
|
| 64 |
+
ΠΠΎΠΏΡΠΎΡ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»Ρ Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅.
|
| 65 |
+
vocabulary : BusinessVocabulary, optional
|
| 66 |
+
ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ. ΠΡΠ»ΠΈ ΠΏΠ΅ΡΠ΅Π΄Π°Π½ β Π΄ΠΎΠ±Π°Π²Π»ΡΠ΅ΡΡΡ Π² ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠ΅
|
| 67 |
+
ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅, Π½Π΅ Π½Π°ΡΡΡΠ°Ρ ΡΡΡΡΠΊΡΡΡΡ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»ΡΡΠΊΠΎΠΉ ΡΠ΅ΠΏΠ»ΠΈΠΊΠΈ.
|
| 68 |
+
"""
|
| 69 |
return [
|
| 70 |
+
{"role": "system", "content": build_system_message(vocabulary)},
|
| 71 |
{"role": "user", "content": build_user_message(schema, question)},
|
| 72 |
]
|
| 73 |
|
| 74 |
|
| 75 |
+
def build_training_example(
|
| 76 |
+
schema: str,
|
| 77 |
+
question: str,
|
| 78 |
+
sql: str,
|
| 79 |
+
vocabulary: "BusinessVocabulary | None" = None,
|
| 80 |
+
) -> list[dict]:
|
| 81 |
+
"""ΠΠΎΠ»Π½ΡΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ Ρ ΡΡΠ°Π»ΠΎΠ½Π½ΠΎΠΉ ΡΠ΅ΠΏΠ»ΠΈΠΊΠΎΠΉ Π°ΡΡΠΈΡΡΠ΅Π½ΡΠ° Π΄Π»Ρ Supervised
|
| 82 |
+
Fine-Tuning (ΡΠ°Π·Π΄Π΅Π» 2.4 ΠΠΠ ).
|
| 83 |
+
"""
|
| 84 |
+
msgs = build_chat_messages(schema, question, vocabulary)
|
| 85 |
msgs.append({"role": "assistant", "content": sql.strip()})
|
| 86 |
return msgs
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ΠΠ±ΡΠ°ΡΠ½Π°Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΡ. ΠΠΌΡ SYSTEM_PROMPT ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π»ΠΎΡΡ Π² ΡΡΠ°ΡΠΎΠΌ ΠΊΠΎΠ΄Π΅
|
| 90 |
+
# ΠΈ Π² ΡΠ΅ΡΡΠ°Ρ
; ΡΠΎΡ
ΡΠ°Π½ΡΠ΅ΠΌ Π°Π»ΠΈΠ°Ρ, ΡΡΠΎΠ±Ρ Π½Π΅ Π»ΠΎΠΌΠ°ΡΡ ΠΈΠΌΠΏΠΎΡΡΡ.
|
| 91 |
+
SYSTEM_PROMPT = BASE_SYSTEM_PROMPT
|
src/data/schema.py
CHANGED
|
@@ -1,76 +1,27 @@
|
|
| 1 |
-
"""SchemaRetriever β
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
import sqlite3
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
create_sql: str
|
| 14 |
-
sample_rows: list[tuple]
|
| 15 |
|
| 16 |
|
| 17 |
-
class SchemaRetriever:
|
| 18 |
-
"""
|
| 19 |
|
| 20 |
def __init__(self, databases_dir: Path | str):
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def db_path(self, db_id: str) -> Path:
|
| 24 |
-
"""Π Spider/PAUQ ΠΊΠ°ΠΆΠ΄Π°Ρ ΠΠ Π»Π΅ΠΆΠΈΡ Π² databases_dir/{db_id}/{db_id}.sqlite."""
|
| 25 |
-
path = self.databases_dir / db_id / f"{db_id}.sqlite"
|
| 26 |
-
if not path.exists():
|
| 27 |
-
raise FileNotFoundError(f"Database file not found: {path}")
|
| 28 |
-
return path
|
| 29 |
-
|
| 30 |
-
def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
|
| 31 |
-
"""ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ ΡΠΏΠΈΡΠΎΠΊ ΡΠ°Π±Π»ΠΈΡ Ρ CREATE-SQL ΠΈ ΠΏΡΠΈΠΌΠ΅ΡΠΎΠΌ ΡΡΡΠΎΠΊ."""
|
| 32 |
-
path = self.db_path(db_id)
|
| 33 |
-
conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
|
| 34 |
-
try:
|
| 35 |
-
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 36 |
-
cur = conn.cursor()
|
| 37 |
-
cur.execute(
|
| 38 |
-
"SELECT name, sql FROM sqlite_master "
|
| 39 |
-
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
| 40 |
-
)
|
| 41 |
-
rows = cur.fetchall()
|
| 42 |
-
|
| 43 |
-
tables: list[TableInfo] = []
|
| 44 |
-
for table_name, create_sql in rows:
|
| 45 |
-
if not create_sql:
|
| 46 |
-
continue
|
| 47 |
-
try:
|
| 48 |
-
cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
|
| 49 |
-
samples = cur.fetchall()
|
| 50 |
-
except sqlite3.Error:
|
| 51 |
-
samples = []
|
| 52 |
-
tables.append(
|
| 53 |
-
TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
|
| 54 |
-
)
|
| 55 |
-
return tables
|
| 56 |
-
finally:
|
| 57 |
-
conn.close()
|
| 58 |
-
|
| 59 |
-
def render_schema(self, db_id: str, include_samples: bool = True) -> str:
|
| 60 |
-
"""Π’Π΅ΠΊΡΡΠΎΠ²ΠΎΠ΅ ΠΏΡΠ΅Π΄ΡΡΠ°Π²Π»Π΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌΡ Π΄Π»Ρ prompt'Π°."""
|
| 61 |
-
tables = self.get_tables(db_id)
|
| 62 |
-
parts: list[str] = []
|
| 63 |
-
for t in tables:
|
| 64 |
-
parts.append(t.create_sql + ";")
|
| 65 |
-
if include_samples and t.sample_rows:
|
| 66 |
-
parts.append(f"-- ΠΡΠΈΠΌΠ΅ΡΡ ΡΡΡΠΎΠΊ ΠΈΠ· {t.name}:")
|
| 67 |
-
for row in t.sample_rows:
|
| 68 |
-
parts.append(f"-- {row}")
|
| 69 |
-
parts.append("")
|
| 70 |
-
return "\n".join(parts).strip()
|
| 71 |
-
|
| 72 |
-
def list_databases(self) -> list[str]:
|
| 73 |
-
"""Π‘ΠΏΠΈΡΠΎΠΊ Π΄ΠΎΡΡΡΠΏΠ½ΡΡ
db_id."""
|
| 74 |
-
if not self.databases_dir.exists():
|
| 75 |
-
return []
|
| 76 |
-
return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
|
|
|
|
| 1 |
+
"""SchemaRetriever β ΡΠ°ΡΠ°Π΄ Π½Π°Π΄ :class:`SpiderSchemaProvider`.
|
| 2 |
+
|
| 3 |
+
ΠΡΡΠΎΡΠΈΡΠ΅ΡΠΊΠΈ Π·Π΄Π΅ΡΡ ΠΆΠΈΠ»Π° ΠΏΠΎΠ»Π½Π°Ρ ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΡ ΡΡΠ΅Π½ΠΈΡ PAUQ/Spider-ΡΡ
Π΅ΠΌ. ΠΠΎΡΠ»Π΅
|
| 4 |
+
ΡΠ΅ΡΠ°ΠΊΡΠΎΡΠΈΠ½Π³Π° (ΡΠ°Π·Π΄Π΅Π» 4.2 Π°ΡΠ΄ΠΈΡΠ°) ΠΎΠ±ΡΠ°Ρ Π»ΠΎΠ³ΠΈΠΊΠ° ΠΏΠ΅ΡΠ΅Π΅Ρ
Π°Π»Π° Π²
|
| 5 |
+
``schema_provider.py``, Π° ``SchemaRetriever`` ΠΎΡΡΠ°Π²Π»Π΅Π½ ΠΊΠ°ΠΊ ΡΠΎΠ½ΠΊΠ°Ρ
|
| 6 |
+
ΠΎΠ±ΡΡΡΠΊΠ° ΡΠ°Π΄ΠΈ ΠΎΠ±ΡΠ°ΡΠ½ΠΎΠΉ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΠΈ ΠΈΠΌΠΏΠΎΡΡΠΎΠ² (``from src.data.schema
|
| 7 |
+
import SchemaRetriever``), ΠΊΠΎΡΠΎΡΡΠ΅ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΡΡΡΡ Π² API ΠΈ ΡΠ΅ΡΡΠ°Ρ
.
|
| 8 |
+
|
| 9 |
+
ΠΠΎΠ²ΡΠΉ ΠΊΠΎΠ΄ ΡΡΠΎΠΈΡ ΠΏΠΈΡΠ°ΡΡ ΡΡΠ°Π·Ρ ΡΠ΅ΡΠ΅Π· :class:`SpiderSchemaProvider`.
|
| 10 |
+
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
|
| 16 |
+
from src.data.schema_provider import SpiderSchemaProvider, TableSchema
|
| 17 |
|
| 18 |
+
# ΠΠ»ΠΈΠ°Ρ Π΄Π»Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΠΈ ΡΠΎ ΡΡΠ°ΡΡΠΌΠΈ ΠΈΠΌΠΏΠΎΡΡΠ°ΠΌΠΈ Π²ΠΈΠ΄Π°
|
| 19 |
+
# ``from src.data.schema import TableInfo``.
|
| 20 |
+
TableInfo = TableSchema
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
+
class SchemaRetriever(SpiderSchemaProvider):
|
| 24 |
+
"""Π‘ΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΡΠΉ Π°Π»ΠΈΠ°Ρ :class:`SpiderSchemaProvider`."""
|
| 25 |
|
| 26 |
def __init__(self, databases_dir: Path | str):
|
| 27 |
+
super().__init__(databases_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data/schema_provider.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ΠΠ΄ΠΈΠ½ΡΠΉ ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ ΡΠ°Π±ΠΎΡΡ ΡΠΎ ΡΡ
Π΅ΠΌΠ°ΠΌΠΈ Π±Π°Π· Π΄Π°Π½Π½ΡΡ
.
|
| 2 |
+
|
| 3 |
+
ΠΠΎ ΡΠ΅ΡΠ°ΠΊΡΠΎΡΠΈΠ½Π³Π° Π² ΠΏΡΠΎΠ΅ΠΊΡΠ΅ ΡΡΡΠ΅ΡΡΠ²ΠΎΠ²Π°Π»ΠΈ Π΄Π²Π΅ Π½Π΅Π·Π°Π²ΠΈΡΠΈΠΌΡΠ΅ ΠΈΠ΅ΡΠ°ΡΡ
ΠΈΠΈ:
|
| 4 |
+
|
| 5 |
+
* ``SchemaRetriever`` (``src/data/schema.py``) β ΡΠΈΡΠ°Π» DDL ΠΈΠ· SQLite-ΡΠ°ΠΉΠ»ΠΎΠ²
|
| 6 |
+
Π² Spider/PAUQ-ΡΡΡΡΠΊΡΡΡΠ΅ ``{databases_dir}/{db_id}/{db_id}.sqlite``.
|
| 7 |
+
* ``DbConnector`` (``src/db/connector.py``) β ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ°Π»ΡΡ ΠΊ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ
|
| 8 |
+
ΠΏΠΎ ΡΡΡΠΎΠΊΠ΅ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ, ΡΠΌΠ΅Π» SQLite/PostgreSQL/MySQL.
|
| 9 |
+
|
| 10 |
+
ΠΠ½ΠΈ ΡΠ΅ΡΠ°Π»ΠΈ ΠΎΠ΄Π½Ρ Π·Π°Π΄Π°ΡΡ, Π½ΠΎ ΠΏΠΎ-ΡΠ°Π·Π½ΠΎΠΌΡ ΠΎΡΠΎΡΠΌΠ»ΡΠ»ΠΈ ΡΠ΅Π·ΡΠ»ΡΡΠ°Ρ
|
| 11 |
+
(``TableInfo`` Π² ΠΊΠ°ΠΆΠ΄ΠΎΠΌ Π±ΡΠ» ΡΠ²ΠΎΠΉ) ΠΈ Π½Π΅ ΠΈΠΌΠ΅Π»ΠΈ ΠΎΠ±ΡΠ΅Π³ΠΎ ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡΠ°. ΠΡΠΎΡ
|
| 12 |
+
ΠΌΠΎΠ΄ΡΠ»Ρ Π²Π²ΠΎΠ΄ΠΈΡ Π΅Π΄ΠΈΠ½ΡΠΉ ΠΏΡΠΎΡΠΎΠΊΠΎΠ» ``SchemaProvider`` ΠΈ ΠΎΠ±ΡΠΈΠΉ dataclass
|
| 13 |
+
``TableSchema``. Π‘ΡΠ°ΡΡΠ΅ ΠΊΠ»Π°ΡΡΡ ΡΡΠ°Π½ΠΎΠ²ΡΡΡΡ ΡΠΎΠ½ΠΊΠΈΠΌΠΈ ΡΠ°ΡΠ°Π΄Π°ΠΌΠΈ ΠΏΠΎΠ²Π΅ΡΡ
|
| 14 |
+
Π½ΠΎΠ²ΡΡ
ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΠΉ.
|
| 15 |
+
|
| 16 |
+
Π‘ΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Π°ΠΌ 3.4 ΠΈ 4.1 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Iterable, Protocol
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class ColumnSchema:
|
| 28 |
+
"""ΠΠΏΠΈΡΠ°Π½ΠΈΠ΅ ΠΊΠΎΠ»ΠΎΠ½ΠΊΠΈ ΡΠ°Π±Π»ΠΈΡΡ."""
|
| 29 |
+
|
| 30 |
+
name: str
|
| 31 |
+
type: str
|
| 32 |
+
nullable: bool = True
|
| 33 |
+
primary_key: bool = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class TableSchema:
|
| 38 |
+
"""Π£Π½ΠΈΡΠΈΡΠΈΡΠΎΠ²Π°Π½Π½ΠΎΠ΅ ΠΎΠΏΠΈΡΠ°Π½ΠΈΠ΅ ΡΠ°Π±Π»ΠΈΡΡ Π½Π΅Π·Π°Π²ΠΈΡΠΈΠΌΠΎ ΠΎΡ ΠΈΡΡΠΎΡΠ½ΠΈΠΊΠ° ΡΡ
Π΅ΠΌΡ.
|
| 39 |
+
|
| 40 |
+
ΠΠΎΠ»Π΅ ``create_sql`` Ρ
ΡΠ°Π½ΠΈΡ ΠΈΡΡ
ΠΎΠ΄Π½ΡΠΉ CREATE TABLE statement, Π΅ΡΠ»ΠΈ ΠΎΠ½
|
| 41 |
+
Π΄ΠΎΡΡΡΠΏΠ΅Π½ (Π°ΠΊΡΡΠ°Π»ΡΠ½ΠΎ Π΄Π»Ρ SQLite β ΠΎΠ½ Π΅Π³ΠΎ ΡΠ°ΠΌ ΠΎΡΠ΄Π°ΡΡ ΠΈΠ· ``sqlite_master``).
|
| 42 |
+
ΠΠΎΠ³Π΄Π° ΠΈΡΡΠΎΡΠ½ΠΈΠΊ ΡΡ
Π΅ΠΌΡ β PostgreSQL/MySQL, DDL Π³Π΅Π½Π΅ΡΠΈΡΡΠ΅ΡΡΡ ΠΈΠ·
|
| 43 |
+
ΠΌΠ΅ΡΠ°Π΄Π°Π½Π½ΡΡ
ΡΠ΅ΡΠ΅Π· :meth:`to_ddl`.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
name: str
|
| 47 |
+
columns: list[ColumnSchema] = field(default_factory=list)
|
| 48 |
+
sample_rows: list[tuple] = field(default_factory=list)
|
| 49 |
+
create_sql: str | None = None
|
| 50 |
+
|
| 51 |
+
def to_ddl(self) -> str:
|
| 52 |
+
"""CREATE TABLE Π΄Π»Ρ ΠΏΠΎΠ΄ΡΡΠ°Π½ΠΎΠ²ΠΊΠΈ Π² ΠΏΡΠΎΠΌΠΏΡ.
|
| 53 |
+
|
| 54 |
+
ΠΡΠ»ΠΈ Π΅ΡΡΡ ΠΎΡΠΈΠ³ΠΈΠ½Π°Π»ΡΠ½ΡΠΉ ``create_sql`` β Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅ΠΌ Π΅Π³ΠΎ, ΡΡΠΎΠ±Ρ
|
| 55 |
+
ΡΠΎΡ
ΡΠ°Π½ΠΈΡΡ Π²ΡΠ΅ Π½ΡΠ°Π½ΡΡ (ΠΎΠ³ΡΠ°Π½ΠΈΡΠ΅Π½ΠΈΡ, FK, AUTOINCREMENT). ΠΠ½Π°ΡΠ΅
|
| 56 |
+
ΡΠΎΠ±ΠΈΡΠ°Π΅ΠΌ ΠΈΠ· ΠΌΠ΅ΡΠ°Π΄Π°Π½Π½ΡΡ
ΠΊΠΎΠ»ΠΎΠ½ΠΎΠΊ.
|
| 57 |
+
"""
|
| 58 |
+
if self.create_sql:
|
| 59 |
+
return self.create_sql.rstrip(";") + ";"
|
| 60 |
+
col_parts: list[str] = []
|
| 61 |
+
for col in self.columns:
|
| 62 |
+
line = f" {col.name} {col.type}"
|
| 63 |
+
if col.primary_key:
|
| 64 |
+
line += " PRIMARY KEY"
|
| 65 |
+
if not col.nullable:
|
| 66 |
+
line += " NOT NULL"
|
| 67 |
+
col_parts.append(line)
|
| 68 |
+
return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SchemaProvider(Protocol):
|
| 72 |
+
"""ΠΡΠΎΡΠΎΠΊΠΎΠ» Π»ΡΠ±ΠΎΠ³ΠΎ ΠΈΡΡΠΎΡΠ½ΠΈΠΊΠ° ΡΡ
Π΅ΠΌΡ Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
.
|
| 73 |
+
|
| 74 |
+
ΠΠΎΠ½ΡΡΠ°ΠΊΡ ΠΌΠΈΠ½ΠΈΠΌΠ°Π»ΡΠ½ΡΠΉ: ΡΠΌΠ΅ΡΡ ΠΏΠ΅ΡΠ΅ΡΠΈΡΠ»ΠΈΡΡ ΡΠ°Π±Π»ΠΈΡΡ ΠΈ ΠΎΡΡΠ΅Π½Π΄Π΅ΡΠΈΡΡ ΡΡ
Π΅ΠΌΡ
|
| 75 |
+
Π² ΡΠ΅ΠΊΡΡ Π΄Π»Ρ ΠΏΠΎΠ΄ΡΡΠ°Π½ΠΎΠ²ΠΊΠΈ Π² ΠΏΡΠΎΠΌΠΏΡ. ΠΡΠΎΠ³ΠΎ Π΄ΠΎΡΡΠ°ΡΠΎΡΠ½ΠΎ ΠΈ Π΄Π»Ρ PAUQ-ΡΡΠ΅Π½Π°ΡΠΈΡ
|
| 76 |
+
(``SpiderSchemaProvider``), ΠΈ Π΄Π»Ρ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ ΠΊ Π±ΠΎΠ΅Π²ΠΎΠΉ ΠΠ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»Ρ
|
| 77 |
+
(``ConnectionSchemaProvider``).
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def list_tables(self) -> list[str]: ...
|
| 81 |
+
def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]: ...
|
| 82 |
+
def render_schema(self, include_samples: bool = True) -> str: ...
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
# Π£ΡΠΈΠ»ΠΈΡΠ° ΡΠ΅Π½Π΄Π΅ΡΠΈΠ½Π³Π° β ΠΎΠ±ΡΠ°Ρ Π΄Π»Ρ Π²ΡΠ΅Ρ
ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΠΉ
|
| 87 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 88 |
+
|
| 89 |
+
def render_tables(tables: Iterable[TableSchema], include_samples: bool = True) -> str:
|
| 90 |
+
"""Π‘ΠΎΠ±ΠΈΡΠ°Π΅Ρ ΡΠ΅ΠΊΡΡΠΎΠ²ΠΎΠ΅ ΠΏΡΠ΅Π΄ΡΡΠ°Π²Π»Π΅Π½ΠΈΠ΅ ΡΠΏΠΈΡΠΊΠ° ΡΠ°Π±Π»ΠΈΡ Π΄Π»Ρ ΠΏΡΠΎΠΌΠΏΡΠ°."""
|
| 91 |
+
parts: list[str] = []
|
| 92 |
+
for t in tables:
|
| 93 |
+
parts.append(t.to_ddl())
|
| 94 |
+
if include_samples and t.sample_rows:
|
| 95 |
+
parts.append(f"-- ΠΡΠΈΠΌΠ΅ΡΡ ΡΡΡΠΎΠΊ ΠΈΠ· {t.name}:")
|
| 96 |
+
for row in t.sample_rows:
|
| 97 |
+
parts.append(f"-- {row}")
|
| 98 |
+
parts.append("")
|
| 99 |
+
return "\n".join(parts).strip()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½β
|
| 103 |
+
# Π Π΅Π°Π»ΠΈΠ·Π°ΡΠΈΡ 1 β Spider/PAUQ-ΡΡΡΡΠΊΡΡΡΠ°
|
| 104 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
+
|
| 106 |
+
class SpiderSchemaProvider:
|
| 107 |
+
"""SchemaProvider Π΄Π»Ρ ΠΊΠ°ΡΠ°Π»ΠΎΠ³Π° ``data/databases/{db_id}/{db_id}.sqlite``.
|
| 108 |
+
|
| 109 |
+
ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ ΠΏΡΠΈ ΡΠ°Π±ΠΎΡΠ΅ Ρ PAUQ/Spider: ΠΊΠ°ΠΆΠ΄Π°Ρ ΠΠ Π»Π΅ΠΆΠΈΡ Π² ΠΎΠ΄Π½ΠΎΠΈΠΌΡΠ½Π½ΠΎΠΉ
|
| 110 |
+
ΠΏΠ°ΠΏΠΊΠ΅. ΠΠ΄ΠΈΠ½ ΡΠΊΠ·Π΅ΠΌΠΏΠ»ΡΡ SpiderSchemaProvider ΠΎΠ±ΡΠ»ΡΠΆΠΈΠ²Π°Π΅Ρ Π²ΡΡ ΠΊΠΎΠ»Π»Π΅ΠΊΡΠΈΡ
|
| 111 |
+
Π±Π°Π· β ΠΊΠΎΠ½ΠΊΡΠ΅ΡΠ½Π°Ρ ΠΠ Π²ΡΠ±ΠΈΡΠ°Π΅ΡΡΡ ΠΏΠΎ ``db_id`` Π² ΠΌΠ΅ΡΠΎΠ΄Π°Ρ
.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, databases_dir: Path | str):
|
| 115 |
+
self.databases_dir = Path(databases_dir)
|
| 116 |
+
|
| 117 |
+
def list_databases(self) -> list[str]:
|
| 118 |
+
if not self.databases_dir.exists():
|
| 119 |
+
return []
|
| 120 |
+
return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
|
| 121 |
+
|
| 122 |
+
def db_path(self, db_id: str) -> Path:
|
| 123 |
+
path = self.databases_dir / db_id / f"{db_id}.sqlite"
|
| 124 |
+
if not path.exists():
|
| 125 |
+
raise FileNotFoundError(f"Database file not found: {path}")
|
| 126 |
+
return path
|
| 127 |
+
|
| 128 |
+
def for_database(self, db_id: str) -> "ConnectionSchemaProvider":
|
| 129 |
+
"""ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ SchemaProvider, ΠΏΡΠΈΠ²ΡΠ·Π°Π½Π½ΡΠΉ ΠΊ ΠΊΠΎΠ½ΠΊΡΠ΅ΡΠ½ΠΎΠΉ ΠΠ."""
|
| 130 |
+
return ConnectionSchemaProvider(f"sqlite:///{self.db_path(db_id)}")
|
| 131 |
+
|
| 132 |
+
# ββ Π‘ΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΡ Ρ ΠΏΡΠ΅Π΄ΡΠ΄ΡΡΠΈΠΌ API SchemaRetriever ββββββββββββββββ
|
| 133 |
+
|
| 134 |
+
def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableSchema]:
|
| 135 |
+
return self.for_database(db_id).get_tables(n_sample_rows=n_sample_rows)
|
| 136 |
+
|
| 137 |
+
def render_schema(self, db_id: str, include_samples: bool = True) -> str:
|
| 138 |
+
return self.for_database(db_id).render_schema(include_samples=include_samples)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
# Π Π΅Π°Π»ΠΈΠ·Π°ΡΠΈΡ 2 β ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½Π°Ρ ΠΠ ΠΏΠΎ connection string
|
| 143 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 144 |
+
|
| 145 |
+
class ConnectionSchemaProvider:
|
| 146 |
+
"""SchemaProvider Π΄Π»Ρ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ (SQLite/PostgreSQL/MySQL).
|
| 147 |
+
|
| 148 |
+
ΠΠ΅Π»Π΅Π³ΠΈΡΡΠ΅Ρ ΡΡΠ΅Π½ΠΈΠ΅ DbConnector'Ρ, Π½ΠΎ Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ ΠΎΠ±ΡΠ΅ΠΊΡΡ Π΅Π΄ΠΈΠ½ΠΎΠ³ΠΎ ΡΠΈΠΏΠ°
|
| 149 |
+
``TableSchema``. ΠΡΠΎ Π½ΡΠΆΠ½ΠΎ, ΡΡΠΎΠ±Ρ ΠΎΠ΄ΠΈΠ½ ΠΈ ΡΠΎΡ ΠΆΠ΅ ΠΊΠΎΠ΄ Π² API ΠΈ Streamlit
|
| 150 |
+
ΠΌΠΎΠ³ ΡΠ°Π±ΠΎΡΠ°ΡΡ ΠΊΠ°ΠΊ Ρ PAUQ-ΡΡΡΡΠΊΡΡΡΠΎΠΉ, ΡΠ°ΠΊ ΠΈ Ρ Π±ΠΎΠ΅Π²ΠΎΠΉ ΠΠ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»Ρ.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
def __init__(self, connection_string: str, n_sample_rows: int = 2):
|
| 154 |
+
# ΠΠΌΠΏΠΎΡΡ Π·Π΄Π΅ΡΡ, ΡΡΠΎΠ±Ρ ΠΈΠ·Π±Π΅ΠΆΠ°ΡΡ ΠΊΠΎΠ»ΡΡΠ΅Π²ΠΎΠΉ Π·Π°Π²ΠΈΡΠΈΠΌΠΎΡΡΠΈ
|
| 155 |
+
# (db.connector β data.schema_provider Π² ΡΠ»ΡΡΠ°Π΅ ΡΠ°ΡΠ°Π΄Π°).
|
| 156 |
+
from src.db.connector import DbConnector
|
| 157 |
+
self._connector = DbConnector(connection_string, n_sample_rows=n_sample_rows)
|
| 158 |
+
self.connection_string = self._connector.connection_string
|
| 159 |
+
|
| 160 |
+
# ββ ΠΠ°Π·ΠΎΠ²ΡΠ΅ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΠΈ SchemaProvider βββββββββββββββββββββββββββββββ
|
| 161 |
+
|
| 162 |
+
def list_tables(self) -> list[str]:
|
| 163 |
+
return self._connector.list_tables()
|
| 164 |
+
|
| 165 |
+
def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]:
|
| 166 |
+
# DbConnector Π² ΡΠ΅ΠΊΡΡΠ΅ΠΉ ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΠΈ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅Ρ ΡΠ²ΠΎΠΉ n_sample_rows ΠΈΠ· ctor;
|
| 167 |
+
# Π΄Π»Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΠΈ Ρ ΠΏΡΠΎΡΠΎΠΊΠΎΠ»ΠΎΠΌ β ΠΈΠ³Π½ΠΎΡΠΈΡΡΠ΅ΠΌ ΠΏΠ°ΡΠ°ΠΌΠ΅ΡΡ Π·Π΄Π΅ΡΡ, Π΄ΠΎΠ²Π΅ΡΡΡ
|
| 168 |
+
# Π½Π°ΡΡΡΠΎΠΉΠΊΠ΅ ΠΊΠΎΠ½Π½Π΅ΠΊΡΠΎΡΠ°. ΠΡΠΈ ΠΆΠ΅Π»Π°Π½ΠΈΠΈ ΠΌΠΎΠΆΠ½ΠΎ Π·Π°Π²Π΅ΡΡΠΈ setter.
|
| 169 |
+
raw = self._connector.get_schema(include_samples=n_sample_rows > 0)
|
| 170 |
+
return [
|
| 171 |
+
TableSchema(
|
| 172 |
+
name=t.name,
|
| 173 |
+
columns=[
|
| 174 |
+
ColumnSchema(
|
| 175 |
+
name=c.name, type=c.type,
|
| 176 |
+
nullable=c.nullable, primary_key=c.primary_key,
|
| 177 |
+
)
|
| 178 |
+
for c in t.columns
|
| 179 |
+
],
|
| 180 |
+
sample_rows=list(t.sample_rows),
|
| 181 |
+
)
|
| 182 |
+
for t in raw
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
def render_schema(self, include_samples: bool = True) -> str:
|
| 186 |
+
return render_tables(self.get_tables(n_sample_rows=2 if include_samples else 0),
|
| 187 |
+
include_samples=include_samples)
|
| 188 |
+
|
| 189 |
+
def test_connection(self) -> bool:
|
| 190 |
+
return self._connector.test_connection()
|
src/db/connector.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
-
"""DbConnector
|
| 2 |
|
| 3 |
-
|
| 4 |
-
SQLite
|
| 5 |
-
PostgreSQL
|
| 6 |
-
MySQL
|
| 7 |
|
| 8 |
-
|
| 9 |
conn = DbConnector("sqlite:///data/demo/sales.sqlite")
|
| 10 |
print(conn.render_schema())
|
| 11 |
tables = conn.list_tables()
|
|
@@ -13,11 +13,14 @@ Primer:
|
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
|
|
|
| 16 |
import sqlite3
|
| 17 |
from dataclasses import dataclass, field
|
| 18 |
from pathlib import Path
|
| 19 |
from urllib.parse import urlparse
|
| 20 |
|
|
|
|
|
|
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
class ColumnInfo:
|
|
@@ -34,7 +37,7 @@ class TableInfo:
|
|
| 34 |
sample_rows: list[tuple] = field(default_factory=list)
|
| 35 |
|
| 36 |
def to_ddl(self) -> str:
|
| 37 |
-
"""
|
| 38 |
col_parts = []
|
| 39 |
for col in self.columns:
|
| 40 |
line = f" {col.name} {col.type}"
|
|
@@ -47,7 +50,7 @@ class TableInfo:
|
|
| 47 |
|
| 48 |
|
| 49 |
class DbConnector:
|
| 50 |
-
"""
|
| 51 |
|
| 52 |
def __init__(self, connection_string: str, n_sample_rows: int = 2):
|
| 53 |
self.connection_string = self._normalize(connection_string)
|
|
@@ -66,7 +69,7 @@ class DbConnector:
|
|
| 66 |
for t in tables:
|
| 67 |
parts.append(t.to_ddl())
|
| 68 |
if include_samples and t.sample_rows:
|
| 69 |
-
parts.append(f"--
|
| 70 |
for row in t.sample_rows:
|
| 71 |
parts.append(f"-- {row}")
|
| 72 |
parts.append("")
|
|
@@ -76,7 +79,8 @@ class DbConnector:
|
|
| 76 |
try:
|
| 77 |
self._get_tables(n_sample_rows=0)
|
| 78 |
return True
|
| 79 |
-
except Exception:
|
|
|
|
| 80 |
return False
|
| 81 |
|
| 82 |
def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
|
|
@@ -87,11 +91,19 @@ class DbConnector:
|
|
| 87 |
elif self._db_type == "mysql":
|
| 88 |
return self._get_tables_mysql(n_sample_rows)
|
| 89 |
else:
|
| 90 |
-
raise ValueError(f"
|
| 91 |
|
| 92 |
def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 96 |
try:
|
| 97 |
cur = conn.cursor()
|
|
@@ -118,8 +130,9 @@ class DbConnector:
|
|
| 118 |
try:
|
| 119 |
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
|
| 120 |
samples = cur.fetchall()
|
| 121 |
-
except sqlite3.Error:
|
| 122 |
-
|
|
|
|
| 123 |
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
|
| 124 |
return tables
|
| 125 |
finally:
|
|
@@ -129,7 +142,7 @@ class DbConnector:
|
|
| 129 |
try:
|
| 130 |
import psycopg2 # type: ignore
|
| 131 |
except ImportError as e:
|
| 132 |
-
raise ImportError("
|
| 133 |
|
| 134 |
conn = psycopg2.connect(self.connection_string)
|
| 135 |
try:
|
|
@@ -166,7 +179,7 @@ class DbConnector:
|
|
| 166 |
try:
|
| 167 |
import pymysql # type: ignore
|
| 168 |
except ImportError as e:
|
| 169 |
-
raise ImportError("
|
| 170 |
|
| 171 |
parsed = urlparse(self.connection_string)
|
| 172 |
conn = pymysql.connect(
|
|
@@ -207,22 +220,22 @@ class DbConnector:
|
|
| 207 |
return Path(cs)
|
| 208 |
|
| 209 |
@staticmethod
|
| 210 |
-
def
|
| 211 |
-
"""
|
| 212 |
-
|
| 213 |
-
import tempfile
|
| 214 |
-
journal = Path(str(path) + "-journal")
|
| 215 |
-
wal = Path(str(path) + "-wal")
|
| 216 |
-
if journal.exists() or wal.exists():
|
| 217 |
-
tmp = Path(tempfile.mktemp(suffix=".sqlite"))
|
| 218 |
-
shutil.copy2(path, tmp)
|
| 219 |
-
return tmp
|
| 220 |
-
return path
|
| 221 |
|
| 222 |
@staticmethod
|
| 223 |
def _normalize(cs: str) -> str:
|
| 224 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
cs = cs.strip()
|
|
|
|
|
|
|
| 226 |
if cs.endswith(".sqlite") or cs.endswith(".db"):
|
| 227 |
return f"sqlite:///{cs}"
|
| 228 |
return cs
|
|
@@ -235,4 +248,4 @@ class DbConnector:
|
|
| 235 |
return "postgresql"
|
| 236 |
if cs.startswith("mysql"):
|
| 237 |
return "mysql"
|
| 238 |
-
raise ValueError(f"
|
|
|
|
| 1 |
+
"""DbConnector β ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ ΠΊ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠΉ ΠΠ ΠΈ ΡΡΠ΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌΡ.
|
| 2 |
|
| 3 |
+
ΠΠΎΠ΄Π΄Π΅ΡΠΆΠΈΠ²Π°Π΅ΠΌΡΠ΅ ΡΠΈΠΏΡ ΠΠ:
|
| 4 |
+
SQLite β ΠΏΡΡΡ ΠΊ ΡΠ°ΠΉΠ»Ρ: "sqlite:///path/to/db.sqlite" ΠΈΠ»ΠΈ ΠΏΡΠΎΡΡΠΎ ΠΏΡΡΡ
|
| 5 |
+
PostgreSQL β "postgresql://user:pass@host:port/dbname" (ΡΡΠ΅Π±ΡΠ΅Ρ psycopg2)
|
| 6 |
+
MySQL β "mysql://user:pass@host:port/dbname" (ΡΡΠ΅Π±ΡΠ΅Ρ pymysql)
|
| 7 |
|
| 8 |
+
ΠΡΠΈΠΌΠ΅Ρ:
|
| 9 |
conn = DbConnector("sqlite:///data/demo/sales.sqlite")
|
| 10 |
print(conn.render_schema())
|
| 11 |
tables = conn.list_tables()
|
|
|
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
+
import logging
|
| 17 |
import sqlite3
|
| 18 |
from dataclasses import dataclass, field
|
| 19 |
from pathlib import Path
|
| 20 |
from urllib.parse import urlparse
|
| 21 |
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
|
| 25 |
@dataclass
|
| 26 |
class ColumnInfo:
|
|
|
|
| 37 |
sample_rows: list[tuple] = field(default_factory=list)
|
| 38 |
|
| 39 |
def to_ddl(self) -> str:
|
| 40 |
+
"""ΠΠ΅Π½Π΅ΡΠΈΡΡΠ΅Ρ CREATE TABLE statement ΠΈΠ· ΠΌΠ΅ΡΠ°Π΄Π°Π½Π½ΡΡ
."""
|
| 41 |
col_parts = []
|
| 42 |
for col in self.columns:
|
| 43 |
line = f" {col.name} {col.type}"
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
class DbConnector:
|
| 53 |
+
"""Π£Π½ΠΈΠ²Π΅ΡΡΠ°Π»ΡΠ½ΡΠΉ ΠΊΠΎΠ½Π½Π΅ΠΊΡΠΎΡ ΠΊ ΠΠ. Π§ΠΈΡΠ°Π΅Ρ ΡΡ
Π΅ΠΌΡ Π΄Π»Ρ ΠΏΠΎΠ΄ΡΡΠ°Π½ΠΎΠ²ΠΊΠΈ Π² ΠΏΡΠΎΠΌΠΏΡ."""
|
| 54 |
|
| 55 |
def __init__(self, connection_string: str, n_sample_rows: int = 2):
|
| 56 |
self.connection_string = self._normalize(connection_string)
|
|
|
|
| 69 |
for t in tables:
|
| 70 |
parts.append(t.to_ddl())
|
| 71 |
if include_samples and t.sample_rows:
|
| 72 |
+
parts.append(f"-- ΠΡΠΈΠΌΠ΅ΡΡ ΡΡΡΠΎΠΊ ΠΈΠ· {t.name}:")
|
| 73 |
for row in t.sample_rows:
|
| 74 |
parts.append(f"-- {row}")
|
| 75 |
parts.append("")
|
|
|
|
| 79 |
try:
|
| 80 |
self._get_tables(n_sample_rows=0)
|
| 81 |
return True
|
| 82 |
+
except Exception as e: # noqa: BLE001
|
| 83 |
+
logger.warning("ΠΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ ΠΊ ΠΠ Π½Π΅ ΡΠ΄Π°Π»ΠΎΡΡ: %s", e)
|
| 84 |
return False
|
| 85 |
|
| 86 |
def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
|
|
|
|
| 91 |
elif self._db_type == "mysql":
|
| 92 |
return self._get_tables_mysql(n_sample_rows)
|
| 93 |
else:
|
| 94 |
+
raise ValueError(f"ΠΠ΅ΠΈΠ·Π²Π΅ΡΡΠ½ΡΠΉ ΡΠΈΠΏ ΠΠ: {self._db_type}")
|
| 95 |
|
| 96 |
def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
|
| 97 |
+
"""SQLite-ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ Π² ΡΠ΅ΠΆΠΈΠΌΠ΅ read-only ΡΠ΅ΡΠ΅Π· URI.
|
| 98 |
+
|
| 99 |
+
immutable=1 Π³ΠΎΠ²ΠΎΡΠΈΡ SQLite, ΡΡΠΎ ΡΠ°ΠΉΠ» Π½Π΅ ΠΈΠ·ΠΌΠ΅Π½ΡΠ΅ΡΡΡ Π²ΠΎ Π²ΡΠ΅ΠΌΡ ΡΠ΅ΡΡΠΈΠΈ,
|
| 100 |
+
ΠΏΠΎΡΡΠΎΠΌΡ journal/WAL-ΡΠ°ΠΉΠ»Ρ ΠΌΠΎΠΆΠ½ΠΎ ΠΈΠ³Π½ΠΎΡΠΈΡΠΎΠ²Π°ΡΡ. ΠΡΠΎ ΡΠ±ΠΈΡΠ°Π΅Ρ ΠΏΡΠ΅ΠΆΠ½ΡΡ
|
| 101 |
+
Π»ΠΎΠ³ΠΈΠΊΡ Ρ ΠΊΠΎΠΏΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ ΠΠ Π²ΠΎ Π²ΡΠ΅ΠΌΠ΅Π½Π½ΡΡ Π΄ΠΈΡΠ΅ΠΊΡΠΎΡΠΈΡ ΠΈ Π·Π°ΠΎΠ΄Π½ΠΎ Π΄Π°ΡΡ
|
| 102 |
+
guardrail-ΡΡΠΎΠ²Π΅Π½Ρ Π±Π΅Π·ΠΎΠΏΠ°ΡΠ½ΠΎΡΡΠΈ: Π»ΡΠ±Π°Ρ ΠΌΠΎΠ΄ΠΈΡΠΈΡΠΈΡΡΡΡΠ°Ρ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΡ
|
| 103 |
+
Π½Π° ΡΠ°ΠΊΠΎΠΌ ΡΠΎΠ΅Π΄ΠΈΠ½Π΅Π½ΠΈΠΈ Π·Π°Π²Π΅ΡΡΠΈΡΡΡ ΠΎΡΠΈΠ±ΠΊΠΎΠΉ.
|
| 104 |
+
"""
|
| 105 |
+
path = self._sqlite_path()
|
| 106 |
+
conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
|
| 107 |
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 108 |
try:
|
| 109 |
cur = conn.cursor()
|
|
|
|
| 130 |
try:
|
| 131 |
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
|
| 132 |
samples = cur.fetchall()
|
| 133 |
+
except sqlite3.Error as e:
|
| 134 |
+
logger.debug("ΠΠ΅ ΡΠ΄Π°Π»ΠΎΡΡ ΠΏΠΎΠ»ΡΡΠΈΡΡ sample-ΡΡΡΠΎΠΊΠΈ Π΄Π»Ρ %s: %s",
|
| 135 |
+
name, e)
|
| 136 |
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
|
| 137 |
return tables
|
| 138 |
finally:
|
|
|
|
| 142 |
try:
|
| 143 |
import psycopg2 # type: ignore
|
| 144 |
except ImportError as e:
|
| 145 |
+
raise ImportError("Π£ΡΡΠ°Π½ΠΎΠ²ΠΈ psycopg2: pip install psycopg2-binary") from e
|
| 146 |
|
| 147 |
conn = psycopg2.connect(self.connection_string)
|
| 148 |
try:
|
|
|
|
| 179 |
try:
|
| 180 |
import pymysql # type: ignore
|
| 181 |
except ImportError as e:
|
| 182 |
+
raise ImportError("Π£ΡΡΠ°Π½ΠΎΠ²ΠΈ pymysql: pip install pymysql") from e
|
| 183 |
|
| 184 |
parsed = urlparse(self.connection_string)
|
| 185 |
conn = pymysql.connect(
|
|
|
|
| 220 |
return Path(cs)
|
| 221 |
|
| 222 |
@staticmethod
|
| 223 |
+
def _sqlite_uri(path: Path) -> str:
|
| 224 |
+
"""Read-only URI Π΄Π»Ρ SQLite Ρ ΠΈΠ³Π½ΠΎΡΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ journal/WAL."""
|
| 225 |
+
return f"file:{path}?mode=ro&immutable=1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
@staticmethod
|
| 228 |
def _normalize(cs: str) -> str:
|
| 229 |
+
"""ΠΡΠ»ΠΈ ΠΏΠ΅ΡΠ΅Π΄Π°Π½ ΠΏΡΠΎΡΡΠΎ ΠΏΡΡΡ ΠΊ ΡΠ°ΠΉΠ»Ρ β ΠΏΡΠ΅Π²ΡΠ°ΡΠ°Π΅ΠΌ Π² sqlite:// URI.
|
| 230 |
+
|
| 231 |
+
ΠΡΠ»ΠΈ ΡΡΡΠΎΠΊΠ° ΡΠΆΠ΅ Π²ΡΠ³Π»ΡΠ΄ΠΈΡ ΠΊΠ°ΠΊ URI (sqlite/postgres/mysql) β
|
| 232 |
+
Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅ΠΌ ΠΊΠ°ΠΊ Π΅ΡΡΡ. ΠΠ΅Π· ΡΡΠΎΠΉ ΠΏΡΠΎΠ²Π΅ΡΠΊΠΈ ΡΡΠ΅Π½Π°ΡΠΈΠΉ Β«ΠΏΠ΅ΡΠ΅Π΄Π°Π»ΠΈ
|
| 233 |
+
ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠΉ sqlite:///pathΒ» ΠΏΡΠΈΠ²ΠΎΠ΄ΠΈΠ» ΠΊ Π΄Π²ΠΎΠΉΠ½ΠΎΠΉ Π½ΠΎΡΠΌΠ°Π»ΠΈΠ·Π°ΡΠΈΠΈ
|
| 234 |
+
ΠΈ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ ΠΊ Π½Π΅ΡΡΡΠ΅ΡΡΠ²ΡΡΡΠ΅ΠΌΡ ΠΏΡΡΠΈ.
|
| 235 |
+
"""
|
| 236 |
cs = cs.strip()
|
| 237 |
+
if cs.startswith(("sqlite:", "postgres", "mysql")):
|
| 238 |
+
return cs
|
| 239 |
if cs.endswith(".sqlite") or cs.endswith(".db"):
|
| 240 |
return f"sqlite:///{cs}"
|
| 241 |
return cs
|
|
|
|
| 248 |
return "postgresql"
|
| 249 |
if cs.startswith("mysql"):
|
| 250 |
return "mysql"
|
| 251 |
+
raise ValueError(f"ΠΠ΅ ΡΠ΄Π°Π»ΠΎΡΡ ΠΎΠΏΡΠ΅Π΄Π΅Π»ΠΈΡΡ ΡΠΈΠΏ ΠΠ: {cs}")
|
src/db/executor.py
CHANGED
|
@@ -1,23 +1,25 @@
|
|
| 1 |
-
"""SqlExecutor
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
print(result.rows)
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
|
|
|
| 12 |
import sqlite3
|
| 13 |
-
from dataclasses import dataclass
|
| 14 |
from pathlib import Path
|
| 15 |
from urllib.parse import urlparse
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
| 18 |
@dataclass
|
| 19 |
class QueryResult:
|
| 20 |
-
"""
|
| 21 |
columns: list[str]
|
| 22 |
rows: list[list]
|
| 23 |
row_count: int
|
|
@@ -39,9 +41,9 @@ class QueryResult:
|
|
| 39 |
|
| 40 |
def to_markdown_table(self) -> str:
|
| 41 |
if self.error:
|
| 42 |
-
return f"
|
| 43 |
if not self.rows:
|
| 44 |
-
return "(
|
| 45 |
header = " | ".join(self.columns)
|
| 46 |
sep = " | ".join(["---"] * len(self.columns))
|
| 47 |
rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
|
|
@@ -49,7 +51,7 @@ class QueryResult:
|
|
| 49 |
|
| 50 |
|
| 51 |
class SqlExecutor:
|
| 52 |
-
"""
|
| 53 |
|
| 54 |
MAX_ROWS = 500
|
| 55 |
|
|
@@ -67,13 +69,14 @@ class SqlExecutor:
|
|
| 67 |
return self._run_mysql(sql)
|
| 68 |
else:
|
| 69 |
return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
|
| 70 |
-
error=f"
|
| 71 |
-
except Exception as e:
|
|
|
|
| 72 |
return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
|
| 73 |
|
| 74 |
def _run_sqlite(self, sql: str) -> QueryResult:
|
| 75 |
-
path = self.
|
| 76 |
-
conn = sqlite3.connect(
|
| 77 |
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 78 |
try:
|
| 79 |
cur = conn.cursor()
|
|
@@ -88,10 +91,12 @@ class SqlExecutor:
|
|
| 88 |
try:
|
| 89 |
import psycopg2 # type: ignore
|
| 90 |
except ImportError as e:
|
| 91 |
-
raise ImportError("
|
| 92 |
|
| 93 |
conn = psycopg2.connect(self.connection_string)
|
| 94 |
try:
|
|
|
|
|
|
|
| 95 |
cur = conn.cursor()
|
| 96 |
cur.execute(sql)
|
| 97 |
cols = [d[0] for d in (cur.description or [])]
|
|
@@ -104,7 +109,7 @@ class SqlExecutor:
|
|
| 104 |
try:
|
| 105 |
import pymysql # type: ignore
|
| 106 |
except ImportError as e:
|
| 107 |
-
raise ImportError("
|
| 108 |
|
| 109 |
parsed = urlparse(self.connection_string)
|
| 110 |
conn = pymysql.connect(
|
|
@@ -116,6 +121,9 @@ class SqlExecutor:
|
|
| 116 |
)
|
| 117 |
try:
|
| 118 |
cur = conn.cursor()
|
|
|
|
|
|
|
|
|
|
| 119 |
cur.execute(sql)
|
| 120 |
cols = [d[0] for d in (cur.description or [])]
|
| 121 |
rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
|
|
@@ -130,16 +138,9 @@ class SqlExecutor:
|
|
| 130 |
return Path(cs)
|
| 131 |
|
| 132 |
@staticmethod
|
| 133 |
-
def
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
journal = Path(str(path) + "-journal")
|
| 137 |
-
wal = Path(str(path) + "-wal")
|
| 138 |
-
if journal.exists() or wal.exists():
|
| 139 |
-
tmp = Path(tempfile.mktemp(suffix=".sqlite"))
|
| 140 |
-
shutil.copy2(path, tmp)
|
| 141 |
-
return tmp
|
| 142 |
-
return path
|
| 143 |
|
| 144 |
@staticmethod
|
| 145 |
def _detect_type(cs: str) -> str:
|
|
@@ -149,4 +150,4 @@ class SqlExecutor:
|
|
| 149 |
return "postgresql"
|
| 150 |
if cs.startswith("mysql"):
|
| 151 |
return "mysql"
|
| 152 |
-
raise ValueError(f"
|
|
|
|
| 1 |
+
"""SqlExecutor β Π²ΡΠΏΠΎΠ»Π½ΡΠ΅Ρ SQL-Π·Π°ΠΏΡΠΎΡ Π½Π° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΡΠ½Π½ΠΎΠΉ ΠΠ.
|
| 2 |
|
| 3 |
+
ΠΠ»Ρ SQLite ΡΠΎΠ΅Π΄ΠΈΠ½Π΅Π½ΠΈΠ΅ ΠΎΡΠΊΡΡΠ²Π°Π΅ΡΡΡ ΡΠ΅ΡΠ΅Π· URI Ρ ``mode=ro&immutable=1`` β
|
| 4 |
+
ΡΡΠΎ ΠΎΠ±Π΅ΡΠΏΠ΅ΡΠΈΠ²Π°Π΅Ρ read-only Π±Π΅Π· ΠΊΠΎΠΏΠΈΡΠΎΠ²Π°Π½ΠΈΡ ΡΠ°ΠΉΠ»Π° ΠΈ ΡΠ΅ΠΆΠ΅Ρ Π»ΡΠ±ΡΠ΅ ΠΏΠΎΠΏΡΡΠΊΠΈ
|
| 5 |
+
Π²ΡΠΏΠΎΠ»Π½ΠΈΡΡ DDL/DML Π½Π° ΡΡΠΎΠ²Π½Π΅ Π΄ΡΠ°ΠΉΠ²Π΅ΡΠ°. ΠΠ»Ρ PostgreSQL/MySQL ΠΎΡΠ΄Π΅Π»ΡΠ½ΡΠΉ
|
| 6 |
+
guardrail ΠΎΡΡΠ°ΡΡΡΡ Π½Π° ΡΡΠΎΡΠΎΠ½Π΅ API (ΡΠΌ. is_select_only Π² postprocess.py).
|
|
|
|
| 7 |
"""
|
| 8 |
|
| 9 |
from __future__ import annotations
|
| 10 |
|
| 11 |
+
import logging
|
| 12 |
import sqlite3
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
from pathlib import Path
|
| 15 |
from urllib.parse import urlparse
|
| 16 |
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
|
| 20 |
@dataclass
|
| 21 |
class QueryResult:
|
| 22 |
+
"""Π Π΅Π·ΡΠ»ΡΡΠ°Ρ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ SQL-Π·Π°ΠΏΡΠΎΡΠ°."""
|
| 23 |
columns: list[str]
|
| 24 |
rows: list[list]
|
| 25 |
row_count: int
|
|
|
|
| 41 |
|
| 42 |
def to_markdown_table(self) -> str:
|
| 43 |
if self.error:
|
| 44 |
+
return f"ΠΡΠΈΠ±ΠΊΠ°: {self.error}"
|
| 45 |
if not self.rows:
|
| 46 |
+
return "(ΠΏΡΡΡΠΎΠΉ ΡΠ΅Π·ΡΠ»ΡΡΠ°Ρ)"
|
| 47 |
header = " | ".join(self.columns)
|
| 48 |
sep = " | ".join(["---"] * len(self.columns))
|
| 49 |
rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
class SqlExecutor:
|
| 54 |
+
"""ΠΡΠΏΠΎΠ»Π½ΡΠ΅Ρ SQL Π½Π° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΡΠ½Π½ΠΎΠΉ ΠΠ."""
|
| 55 |
|
| 56 |
MAX_ROWS = 500
|
| 57 |
|
|
|
|
| 69 |
return self._run_mysql(sql)
|
| 70 |
else:
|
| 71 |
return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
|
| 72 |
+
error=f"ΠΠ΅ΠΈΠ·Π²Π΅ΡΡΠ½ΡΠΉ ΡΠΈΠΏ ΠΠ: {self._db_type}")
|
| 73 |
+
except Exception as e: # noqa: BLE001
|
| 74 |
+
logger.warning("ΠΡΠΈΠ±ΠΊΠ° Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ SQL: %s", e)
|
| 75 |
return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
|
| 76 |
|
| 77 |
def _run_sqlite(self, sql: str) -> QueryResult:
|
| 78 |
+
path = self._sqlite_path()
|
| 79 |
+
conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
|
| 80 |
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 81 |
try:
|
| 82 |
cur = conn.cursor()
|
|
|
|
| 91 |
try:
|
| 92 |
import psycopg2 # type: ignore
|
| 93 |
except ImportError as e:
|
| 94 |
+
raise ImportError("Π£ΡΡΠ°Π½ΠΎΠ²ΠΈ psycopg2: pip install psycopg2-binary") from e
|
| 95 |
|
| 96 |
conn = psycopg2.connect(self.connection_string)
|
| 97 |
try:
|
| 98 |
+
# Π’ΡΠ°Π½Π·Π°ΠΊΡΠΈΡ Π² ΡΠ΅ΠΆΠΈΠΌΠ΅ READ ONLY β guardrail Π΄ΡΠ°ΠΉΠ²Π΅ΡΠ½ΠΎΠ³ΠΎ ΡΡΠΎΠ²Π½Ρ.
|
| 99 |
+
conn.set_session(readonly=True, autocommit=False)
|
| 100 |
cur = conn.cursor()
|
| 101 |
cur.execute(sql)
|
| 102 |
cols = [d[0] for d in (cur.description or [])]
|
|
|
|
| 109 |
try:
|
| 110 |
import pymysql # type: ignore
|
| 111 |
except ImportError as e:
|
| 112 |
+
raise ImportError("Π£ΡΡΠ°Π½ΠΎΠ²ΠΈ pymysql: pip install pymysql") from e
|
| 113 |
|
| 114 |
parsed = urlparse(self.connection_string)
|
| 115 |
conn = pymysql.connect(
|
|
|
|
| 121 |
)
|
| 122 |
try:
|
| 123 |
cur = conn.cursor()
|
| 124 |
+
# MySQL Π½Π΅ ΠΈΠΌΠ΅Π΅Ρ Β«Π³Π»ΠΎΠ±Π°Π»ΡΠ½ΠΎΠ³ΠΎΒ» read-only ΡΠ»Π°Π³Π° Π² Π΄ΡΠ°ΠΉΠ²Π΅ΡΠ΅,
|
| 125 |
+
# Π½ΠΎ ΠΌΡ ΠΌΠΎΠΆΠ΅ΠΌ ΡΡΠ°ΡΡΠΎΠ²Π°ΡΡ read-only-ΡΡΠ°Π½Π·Π°ΠΊΡΠΈΡ.
|
| 126 |
+
cur.execute("START TRANSACTION READ ONLY")
|
| 127 |
cur.execute(sql)
|
| 128 |
cols = [d[0] for d in (cur.description or [])]
|
| 129 |
rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
|
|
|
|
| 138 |
return Path(cs)
|
| 139 |
|
| 140 |
@staticmethod
|
| 141 |
+
def _sqlite_uri(path: Path) -> str:
|
| 142 |
+
"""Read-only URI Π΄Π»Ρ SQLite Ρ ΠΈοΏ½οΏ½Π½ΠΎΡΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ journal/WAL."""
|
| 143 |
+
return f"file:{path}?mode=ro&immutable=1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
@staticmethod
|
| 146 |
def _detect_type(cs: str) -> str:
|
|
|
|
| 150 |
return "postgresql"
|
| 151 |
if cs.startswith("mysql"):
|
| 152 |
return "mysql"
|
| 153 |
+
raise ValueError(f"ΠΠ΅ ΡΠ΄Π°Π»ΠΎΡΡ ΠΎΠΏΡΠ΅Π΄Π΅Π»ΠΈΡΡ ΡΠΈΠΏ ΠΠ: {cs}")
|
src/models/inference.py
CHANGED
|
@@ -1,21 +1,25 @@
|
|
| 1 |
"""ΠΠ°Π³ΡΡΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ + LoRA-Π°Π΄Π°ΠΏΡΠ΅ΡΠ° ΠΈ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ.
|
| 2 |
|
| 3 |
-
ΠΠ° Π΄Π΅ΡΠΊΡΠΎΠΏΠ΅/Π½ΠΎΡΡΠ±ΡΠΊΠ΅ Π±Π΅Π· GPU ΡΠ°Π±ΠΎΡΠ°Π΅Ρ Π½Π° CPU. ΠΠ΅Π΄Π»Π΅Π½Π½ΠΎ, Π½ΠΎ Π΄ΠΎΡΡΠ°ΡΠΎΡΠ½ΠΎ Π΄Π»Ρ
|
| 4 |
-
ΠΠ° Kaggle/Colab β Π½Π° GPU, Π±ΡΡΡΡΠ΅Π΅.
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
| 9 |
from dataclasses import dataclass
|
| 10 |
from pathlib import Path
|
| 11 |
|
| 12 |
import torch
|
| 13 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
|
|
|
|
| 15 |
from src.config import settings
|
| 16 |
from src.data.prompt import build_chat_messages
|
| 17 |
from src.models.postprocess import postprocess
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
@dataclass
|
| 21 |
class GenerationResult:
|
|
@@ -39,11 +43,18 @@ class InferenceEngine:
|
|
| 39 |
self.model = None
|
| 40 |
self._loaded = False
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def load(self) -> None:
|
| 43 |
"""ΠΠ΅Π½ΠΈΠ²ΠΎ Π³ΡΡΠ·ΠΈΠΌ ΠΌΠΎΠ΄Π΅Π»Ρ. ΠΠ° CPU Π±Π΅Π· ΠΊΠ²Π°Π½ΡΠΈΠ·Π°ΡΠΈΠΈ."""
|
| 44 |
if self._loaded:
|
| 45 |
return
|
| 46 |
|
|
|
|
|
|
|
| 47 |
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
|
| 48 |
# bfloat16 Π²Π΄Π²ΠΎΠ΅ ΠΌΠ΅Π½ΡΡΠ΅ float32 (~6 ΠΠ vs ~12 ΠΠ) ΠΈ ΠΏΠΎΠ΄Π΄Π΅ΡΠΆΠΈΠ²Π°Π΅ΡΡΡ Π½Π° CPU
|
| 49 |
self.model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -57,38 +68,53 @@ class InferenceEngine:
|
|
| 57 |
adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
|
| 58 |
try:
|
| 59 |
from peft import PeftModel
|
|
|
|
| 60 |
self.model = PeftModel.from_pretrained(self.model, adapter_id)
|
| 61 |
except ImportError:
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
self.model.eval()
|
| 65 |
self._loaded = True
|
|
|
|
| 66 |
|
| 67 |
def generate(
|
| 68 |
self,
|
| 69 |
schema: str,
|
| 70 |
question: str,
|
|
|
|
| 71 |
max_new_tokens: int | None = None,
|
| 72 |
) -> GenerationResult:
|
| 73 |
-
"""ΠΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ schema (ΡΠ΅ΠΊΡΡ DDL) ΠΈ Π²ΠΎΠΏΡΠΎΡ, Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ SQL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if not self._loaded:
|
| 75 |
self.load()
|
| 76 |
|
| 77 |
-
messages = build_chat_messages(schema, question)
|
| 78 |
prompt = self.tokenizer.apply_chat_template(
|
| 79 |
messages, tokenize=False, add_generation_prompt=True
|
| 80 |
)
|
| 81 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
with torch.no_grad():
|
| 84 |
-
output_ids = self.model.generate(
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
do_sample=settings.do_sample,
|
| 88 |
-
temperature=settings.temperature if settings.do_sample else 1.0,
|
| 89 |
-
pad_token_id=self.tokenizer.eos_token_id,
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
new_tokens = output_ids[0][inputs["input_ids"].shape[1] :]
|
| 93 |
raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 94 |
return GenerationResult(sql=postprocess(raw), raw_output=raw)
|
|
|
|
| 1 |
"""ΠΠ°Π³ΡΡΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ + LoRA-Π°Π΄Π°ΠΏΡΠ΅ΡΠ° ΠΈ ΠΈΠ½ΡΠ΅ΡΠ΅Π½Ρ.
|
| 2 |
|
| 3 |
+
ΠΠ° Π΄Π΅ΡΠΊΡΠΎΠΏΠ΅/Π½ΠΎΡΡΠ±ΡΠΊΠ΅ Π±Π΅Π· GPU ΡΠ°Π±ΠΎΡΠ°Π΅Ρ Π½Π° CPU. ΠΠ΅Π΄Π»Π΅Π½Π½ΠΎ, Π½ΠΎ Π΄ΠΎΡΡΠ°ΡΠΎΡΠ½ΠΎ Π΄Π»Ρ
|
| 4 |
+
ΡΠ°Π·ΡΠ°Π±ΠΎΡΠΊΠΈ ΠΈ Π΄Π΅ΠΌΠΎ. ΠΠ° Kaggle/Colab β Π½Π° GPU, Π±ΡΡΡΡΠ΅Π΅.
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
import logging
|
| 10 |
from dataclasses import dataclass
|
| 11 |
from pathlib import Path
|
| 12 |
|
| 13 |
import torch
|
| 14 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
|
| 16 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 17 |
from src.config import settings
|
| 18 |
from src.data.prompt import build_chat_messages
|
| 19 |
from src.models.postprocess import postprocess
|
| 20 |
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
class GenerationResult:
|
|
|
|
| 43 |
self.model = None
|
| 44 |
self._loaded = False
|
| 45 |
|
| 46 |
+
@property
|
| 47 |
+
def loaded(self) -> bool:
|
| 48 |
+
"""ΠΡΠ±Π»ΠΈΡΠ½ΠΎΠ΅ ΡΠ²ΠΎΠΉΡΡΠ²ΠΎ β ΡΡΠ°ΡΡΡ Π·Π°Π³ΡΡΠ·ΠΊΠΈ ΠΌΠΎΠ΄Π΅Π»ΠΈ."""
|
| 49 |
+
return self._loaded
|
| 50 |
+
|
| 51 |
def load(self) -> None:
|
| 52 |
"""ΠΠ΅Π½ΠΈΠ²ΠΎ Π³ΡΡΠ·ΠΈΠΌ ΠΌΠΎΠ΄Π΅Π»Ρ. ΠΠ° CPU Π±Π΅Π· ΠΊΠ²Π°Π½ΡΠΈΠ·Π°ΡΠΈΠΈ."""
|
| 53 |
if self._loaded:
|
| 54 |
return
|
| 55 |
|
| 56 |
+
logger.info("ΠΠ°Π³ΡΡΠ·ΠΊΠ° Π±Π°Π·ΠΎΠ²ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ %s Π½Π° ΡΡΡΡΠΎΠΉΡΡΠ²ΠΎ %s",
|
| 57 |
+
self.base_model_name, self.device)
|
| 58 |
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
|
| 59 |
# bfloat16 Π²Π΄Π²ΠΎΠ΅ ΠΌΠ΅Π½ΡΡΠ΅ float32 (~6 ΠΠ vs ~12 ΠΠ) ΠΈ ΠΏΠΎΠ΄Π΄Π΅ΡΠΆΠΈΠ²Π°Π΅ΡΡΡ Π½Π° CPU
|
| 60 |
self.model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 68 |
adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
|
| 69 |
try:
|
| 70 |
from peft import PeftModel
|
| 71 |
+
logger.info("ΠΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ LoRA-Π°Π΄Π°ΠΏΡΠ΅ΡΠ° %s", adapter_id)
|
| 72 |
self.model = PeftModel.from_pretrained(self.model, adapter_id)
|
| 73 |
except ImportError:
|
| 74 |
+
logger.warning("peft Π½Π΅ ΡΡΡΠ°Π½ΠΎΠ²Π»Π΅Π½, ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ Π±Π°Π·ΠΎΠ²Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ Π±Π΅Π· LoRA")
|
| 75 |
+
except Exception as e: # noqa: BLE001 β Π»ΠΎΠ³ Π΄ΠΎΡΡΠ°ΡΠΎΡΠ΅Π½, Π±Π΅Π· ΠΏΠ°Π΄Π΅Π½ΠΈΡ
|
| 76 |
+
logger.warning("ΠΠ΅ ΡΠ΄Π°Π»ΠΎΡΡ ΠΏΠΎΠ΄Π³ΡΡΠ·ΠΈΡΡ LoRA-Π°Π΄Π°ΠΏΡΠ΅Ρ %s: %s",
|
| 77 |
+
adapter_id, e)
|
| 78 |
|
| 79 |
self.model.eval()
|
| 80 |
self._loaded = True
|
| 81 |
+
logger.info("InferenceEngine Π³ΠΎΡΠΎΠ² ΠΊ ΡΠ°Π±ΠΎΡΠ΅")
|
| 82 |
|
| 83 |
def generate(
|
| 84 |
self,
|
| 85 |
schema: str,
|
| 86 |
question: str,
|
| 87 |
+
vocabulary: BusinessVocabulary | None = None,
|
| 88 |
max_new_tokens: int | None = None,
|
| 89 |
) -> GenerationResult:
|
| 90 |
+
"""ΠΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ schema (ΡΠ΅ΠΊΡΡ DDL) ΠΈ Π²ΠΎΠΏΡΠΎΡ, Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ SQL.
|
| 91 |
+
|
| 92 |
+
ΠΡΠ»ΠΈ ΠΏΠ΅ΡΠ΅Π΄Π°Π½ Π½Π΅ΠΏΡΡΡΠΎΠΉ ``vocabulary``, Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ΅ΡΠΌΠΈΠ½Ρ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ
|
| 93 |
+
ΠΏΠΎΠ΄ΠΌΠ΅ΡΠΈΠ²Π°ΡΡΡΡ Π² ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠ΅ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅ ΡΠ΅ΡΠ΅Π· PromptBuilder.
|
| 94 |
+
ΠΡΠΎ ΡΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Ρ 3.6 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ.
|
| 95 |
+
"""
|
| 96 |
if not self._loaded:
|
| 97 |
self.load()
|
| 98 |
|
| 99 |
+
messages = build_chat_messages(schema, question, vocabulary=vocabulary)
|
| 100 |
prompt = self.tokenizer.apply_chat_template(
|
| 101 |
messages, tokenize=False, add_generation_prompt=True
|
| 102 |
)
|
| 103 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
| 104 |
|
| 105 |
+
# ΠΠ°ΡΠ°ΠΌΠ΅ΡΡΡ ΡΡΠΌΠΏΠ»ΠΈΠ½Π³Π°. ΠΡΠΈ do_sample=False temperature ΠΈΠ³Π½ΠΎΡΠΈΡΡΠ΅ΡΡΡ,
|
| 106 |
+
# ΠΏΠΎΡΡΠΎΠΌΡ Π½Π΅ ΠΏΠ΅ΡΠ΅Π΄Π°ΡΠΌ Π΅Ρ β ΠΈΠ½Π°ΡΠ΅ transformers Π²ΡΠ²ΠΎΠ΄ΠΈΡ warning.
|
| 107 |
+
gen_kwargs = {
|
| 108 |
+
"max_new_tokens": max_new_tokens or settings.max_new_tokens,
|
| 109 |
+
"do_sample": settings.do_sample,
|
| 110 |
+
"pad_token_id": self.tokenizer.eos_token_id,
|
| 111 |
+
}
|
| 112 |
+
if settings.do_sample:
|
| 113 |
+
gen_kwargs["temperature"] = settings.temperature
|
| 114 |
+
|
| 115 |
with torch.no_grad():
|
| 116 |
+
output_ids = self.model.generate(**inputs, **gen_kwargs)
|
| 117 |
+
|
| 118 |
+
new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 120 |
return GenerationResult(sql=postprocess(raw), raw_output=raw)
|
src/models/postprocess.py
CHANGED
|
@@ -1,24 +1,69 @@
|
|
| 1 |
-
"""ΠΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ° SQL: ΡΠΈΡΡΠΊΠ° Π²ΡΠ²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 5 |
import re
|
| 6 |
|
| 7 |
import sqlglot
|
|
|
|
| 8 |
from sqlglot.errors import ParseError
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def strip_model_artifacts(text: str) -> str:
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
# ΠΡΠ»ΠΈ Π΅ΡΡΡ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ SQL β Π±Π΅ΡΡΠΌ ΠΏΠ΅ΡΠ²ΡΠΉ Π΄ΠΎ ΡΠΎΡΠΊΠΈ Ρ Π·Π°ΠΏΡΡΠΎΠΉ
|
| 22 |
text = text.strip()
|
| 23 |
if ";" in text:
|
| 24 |
head, _, _ = text.partition(";")
|
|
@@ -28,23 +73,75 @@ def strip_model_artifacts(text: str) -> str:
|
|
| 28 |
|
| 29 |
|
| 30 |
def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
|
| 31 |
-
"""Π
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
try:
|
| 33 |
-
sqlglot.parse_one(sql, dialect=dialect)
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
|
| 40 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
try:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
return re.sub(r"\s+", " ", sql.
|
| 46 |
|
| 47 |
|
| 48 |
def postprocess(raw_output: str) -> str:
|
| 49 |
-
"""ΠΠΎΠ»Π½ΡΠΉ pipeline ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠΈ.
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ΠΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠ° SQL: ΡΠΈΡΡΠΊΠ° Π²ΡΠ²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ, Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΡ ΠΈ Π½ΠΎΡΠΌΠ°Π»ΠΈΠ·Π°ΡΠΈΡ.
|
| 2 |
+
|
| 3 |
+
Π‘ΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Ρ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ. Pipeline:
|
| 4 |
+
raw_output βββΊ strip_model_artifacts βββΊ is_valid_sql βββΊ sql | ""
|
| 5 |
+
|
| 6 |
+
ΠΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΠΎ ΠΌΠΎΠ΄ΡΠ»Ρ ΠΏΡΠ΅Π΄ΠΎΡΡΠ°Π²Π»ΡΠ΅Ρ:
|
| 7 |
+
is_select_only(sql) β AST-ΡΡΠΎΠ²Π½Π΅Π²ΡΠΉ Π³Π²Π°ΡΠ΄Π΅ΠΉΠ» ΠΏΡΠΎΡΠΈΠ² DDL/DML
|
| 8 |
+
ΠΏΠ΅ΡΠ΅Π΄ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ ΡΠ³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΠΎΠ³ΠΎ Π·Π°ΠΏΡΠΎΡΠ°;
|
| 9 |
+
normalize_sql(sql) β ΠΊΠ°Π½ΠΎΠ½ΠΈΡΠ΅ΡΠΊΠ°Ρ ΡΠΎΡΠΌΠ° Π΄Π»Ρ ΡΠ°ΡΡΡΡΠ° Exact Match
|
| 10 |
+
(ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠ° Ρ evaluate_pauq.py).
|
| 11 |
+
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
+
import logging
|
| 16 |
import re
|
| 17 |
|
| 18 |
import sqlglot
|
| 19 |
+
from sqlglot import exp
|
| 20 |
from sqlglot.errors import ParseError
|
| 21 |
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
# ΠΠ»ΡΡΠ΅Π²ΡΠ΅ ΡΠ»ΠΎΠ²Π°, Ρ ΠΊΠΎΡΠΎΡΡΡ
ΠΌΠΎΠΆΠ΅Ρ Π½Π°ΡΠΈΠ½Π°ΡΡΡΡ ΠΊΠΎΡΡΠ΅ΠΊΡΠ½ΡΠΉ SQL-Π·Π°ΠΏΡΠΎΡ.
|
| 25 |
+
_SQL_START_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")
|
| 26 |
+
_SQL_START_REGEX = re.compile(
|
| 27 |
+
r"\b(" + "|".join(_SQL_START_KEYWORDS) + r")\b",
|
| 28 |
+
flags=re.IGNORECASE,
|
| 29 |
+
)
|
| 30 |
+
_FENCE_REGEX = re.compile(r"```(?:sql)?\s*(.*?)```", flags=re.DOTALL | re.IGNORECASE)
|
| 31 |
+
_PREFIX_REGEX = re.compile(r"^\s*(?:SQL|ΠΡΠ²Π΅Ρ|Answer)\s*:\s*", flags=re.IGNORECASE)
|
| 32 |
+
|
| 33 |
+
# Π’ΠΈΠΏΡ AST-ΡΠ·Π»ΠΎΠ², ΠΊΠΎΡΠΎΡΡΠ΅ ΠΌΡ ΡΡΠΈΡΠ°Π΅ΠΌ Β«ΠΎΡΠΌΡΡΠ»Π΅Π½Π½ΡΠΌΠΈΒ» SQL-Π·Π°ΠΏΡΠΎΡΠ°ΠΌΠΈ.
|
| 34 |
+
# sqlglot β Π»ΠΎΡΠ»ΡΠ½ΡΠΉ ΠΏΠ°ΡΡΠ΅Ρ: 'garbage text' ΠΎΠ½ ΡΠ°ΡΠΏΠ°ΡΡΠΈΡ ΠΊΠ°ΠΊ Column/Table.
|
| 35 |
+
# ΠΠ΅Π· ΠΏΡΠΎΠ²Π΅ΡΠΊΠΈ isinstance ΡΠ°ΠΊΠΈΠ΅ ΡΠ»ΡΡΠ°ΠΈ Π±ΡΠ΄ΡΡ ΠΏΡΠΎΡ
ΠΎΠ΄ΠΈΡΡ is_valid_sql.
|
| 36 |
+
_VALID_ROOT_TYPES: tuple[type[exp.Expression], ...] = (
|
| 37 |
+
exp.Select,
|
| 38 |
+
exp.With,
|
| 39 |
+
exp.Insert,
|
| 40 |
+
exp.Update,
|
| 41 |
+
exp.Delete,
|
| 42 |
+
exp.Union,
|
| 43 |
+
exp.Intersect,
|
| 44 |
+
exp.Except,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
|
| 48 |
def strip_model_artifacts(text: str) -> str:
|
| 49 |
+
"""ΠΡΠΈΡΠ°Π΅Ρ Π²ΡΠ²ΠΎΠ΄ ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΎΡ markdown ΠΈ ΠΏΠΎΡΡΠ½Π΅Π½ΠΈΠΉ Π΄ΠΎ Π½Π°ΡΠ°Π»Π° SQL-Π·Π°ΠΏΡΠΎΡΠ°.
|
| 50 |
+
|
| 51 |
+
Π¨Π°Π³ΠΈ:
|
| 52 |
+
1. ΠΡΠ»ΠΈ ΠΎΡΠ²Π΅Ρ ΠΎΠ±ΡΡΠ½ΡΡ Π² ```sql ... ``` β ΠΈΠ·Π²Π»Π΅ΠΊΠ°Π΅ΡΡΡ ΡΠΎΠ΄Π΅ΡΠΆΠΈΠΌΠΎΠ΅.
|
| 53 |
+
2. Π£Π΄Π°Π»ΡΡΡΡΡ ΠΏΡΠ΅ΡΠΈΠΊΡΡ Π²ΠΈΠ΄Π° Β«SQL:Β», Β«ΠΡΠ²Π΅Ρ:Β», Β«Answer:Β».
|
| 54 |
+
3. ΠΡΠ΅ΡΡΡ ΠΏΠ΅ΡΠ²ΠΎΠ΅ Π²Ρ
ΠΎΠΆΠ΄Π΅Π½ΠΈΠ΅ SQL-ΠΊΠ»ΡΡΠ΅Π²ΠΎΠ³ΠΎ ΡΠ»ΠΎΠ²Π°, Π²ΡΡ Π΄ΠΎ Π½Π΅Π³ΠΎ ΠΎΡΠ±ΡΠ°ΡΡΠ²Π°Π΅ΡΡΡ.
|
| 55 |
+
4. ΠΠ΅ΡΡΡΡΡ ΠΏΠ΅ΡΠ²ΡΠΉ statement Π΄ΠΎ ΠΏΠ΅ΡΠ²ΠΎΠΉ ΡΠΎΡΠΊΠΈ Ρ Π·Π°ΠΏΡΡΠΎΠΉ Π²ΠΊΠ»ΡΡΠΈΡΠ΅Π»ΡΠ½ΠΎ.
|
| 56 |
+
"""
|
| 57 |
+
fence = _FENCE_REGEX.search(text)
|
| 58 |
+
if fence:
|
| 59 |
+
text = fence.group(1)
|
| 60 |
|
| 61 |
+
text = _PREFIX_REGEX.sub("", text)
|
| 62 |
+
|
| 63 |
+
keyword_match = _SQL_START_REGEX.search(text)
|
| 64 |
+
if keyword_match:
|
| 65 |
+
text = text[keyword_match.start():]
|
| 66 |
|
|
|
|
| 67 |
text = text.strip()
|
| 68 |
if ";" in text:
|
| 69 |
head, _, _ = text.partition(";")
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
|
| 76 |
+
"""ΠΡΠΎΠ²Π΅ΡΡΠ΅Ρ, ΡΡΠΎ ΡΡΡΠΎΠΊΠ° β ΡΡΠΎ Π²Π°Π»ΠΈΠ΄Π½ΡΠΉ SQL-Π·Π°ΠΏΡΠΎΡ.
|
| 77 |
+
|
| 78 |
+
ΠΠ°ΡΡΠΈΡΡΡ ΡΠ΅ΡΠ΅Π· sqlglot ΠΈ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΠΎ ΠΏΡΠΎΠ²Π΅ΡΡΠ΅ΡΡΡ, ΡΡΠΎ ΠΊΠΎΡΠ΅Π½Ρ AST β
|
| 79 |
+
ΡΡΠΎ ΠΎΠ΄ΠΈΠ½ ΠΈΠ· Β«ΠΎΡΠΌΡΡΠ»Π΅Π½Π½ΡΡ
Β» ΡΠΈΠΏΠΎΠ² Π·Π°ΠΏΡΠΎΡΠ° (SELECT/WITH/INSERT/UPDATE/
|
| 80 |
+
DELETE/UNION). ΠΠ΅Π· ΠΏΡΠΎΠ²Π΅ΡΠΊΠΈ ΡΠΈΠΏΠ° sqlglot ΠΏΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ Π·Π° SQL Π΄Π°ΠΆΠ΅
|
| 81 |
+
ΡΠ»ΡΡΠ°ΠΉΠ½ΡΠ΅ ΠΈΠ΄Π΅Π½ΡΠΈΡΠΈΠΊΠ°ΡΠΎΡΡ, ΠΏΠΎΡΠΎΠΌΡ ΡΡΠΎ ΠΎΠ½ Π»ΠΎΡΠ»ΡΠ½ΡΠΉ ΠΏΠ°ΡΡΠ΅Ρ.
|
| 82 |
+
"""
|
| 83 |
+
if not sql or not sql.strip():
|
| 84 |
+
return False
|
| 85 |
try:
|
| 86 |
+
parsed = sqlglot.parse_one(sql, dialect=dialect)
|
| 87 |
+
except (ParseError, ValueError, TypeError) as e:
|
| 88 |
+
logger.debug("sqlglot Π½Π΅ ΡΠΌΠΎΠ³ ΡΠ°Π·ΠΎΠ±ΡΠ°ΡΡ SQL: %s", e)
|
| 89 |
+
return False
|
| 90 |
+
if parsed is None:
|
| 91 |
return False
|
| 92 |
+
return isinstance(parsed, _VALID_ROOT_TYPES)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def is_select_only(sql: str, dialect: str = "sqlite") -> bool:
|
| 96 |
+
"""ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ True, Π΅ΡΠ»ΠΈ SQL β ΡΡΠΎ SELECT (Π² Ρ. Ρ. Π²Π½ΡΡΡΠΈ WITH-CTE).
|
| 97 |
+
|
| 98 |
+
ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΡΡΡ ΠΊΠ°ΠΊ guardrail ΠΏΠ΅ΡΠ΅Π΄ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ ΡΠ³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΠΎΠ³ΠΎ Π·Π°ΠΏΡΠΎΡΠ°
|
| 99 |
+
Π½Π° ΡΠ΅Π°Π»ΡΠ½ΠΎΠΉ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
: ΠΌΠΎΠ΄Π΅Π»Ρ Π½Π΅ Π΄ΠΎΠ»ΠΆΠ½Π° ΠΏΠΎΠ»ΡΡΠΈΡΡ Π²ΠΎΠ·ΠΌΠΎΠΆΠ½ΠΎΡΡΡ Π²ΡΠ·Π²Π°ΡΡ
|
| 100 |
+
DROP/UPDATE/DELETE/INSERT, Π΄Π°ΠΆΠ΅ Π΅ΡΠ»ΠΈ ΡΠ°ΠΊΠΈΠ΅ ΠΊΠΎΠ½ΡΡΡΡΠΊΡΠΈΠΈ ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠΈ
|
| 101 |
+
ΠΊΠΎΡΡΠ΅ΠΊΡΠ½Ρ.
|
| 102 |
+
"""
|
| 103 |
+
if not sql or not sql.strip():
|
| 104 |
+
return False
|
| 105 |
+
try:
|
| 106 |
+
parsed = sqlglot.parse_one(sql, dialect=dialect)
|
| 107 |
+
except (ParseError, ValueError, TypeError):
|
| 108 |
+
return False
|
| 109 |
+
if parsed is None:
|
| 110 |
+
return False
|
| 111 |
+
if isinstance(parsed, exp.Select):
|
| 112 |
+
return True
|
| 113 |
+
if isinstance(parsed, exp.With):
|
| 114 |
+
return isinstance(parsed.this, exp.Select)
|
| 115 |
+
if isinstance(parsed, exp.Subquery):
|
| 116 |
+
return isinstance(parsed.this, exp.Select)
|
| 117 |
+
return False
|
| 118 |
|
| 119 |
|
| 120 |
def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
|
| 121 |
+
"""ΠΠ°Π½ΠΎΠ½ΠΈΡΠ΅ΡΠΊΠ°Ρ ΡΠΎΡΠΌΠ° Π΄Π»Ρ ΡΠ°ΡΡΡΡΠ° Exact Match.
|
| 122 |
+
|
| 123 |
+
ΠΡΠΏΠΎΠ»ΡΠ·ΡΠ΅Ρ sqlglot Ρ ΡΠ»Π°Π³ΠΎΠΌ ``normalize=True`` β ΡΡΠΎ Π½ΠΎΡΠΌΠ°Π»ΠΈΠ·ΡΠ΅Ρ ΡΠ΅Π³ΠΈΡΡΡ
|
| 124 |
+
ΠΊΠ»ΡΡΠ΅Π²ΡΡ
ΡΠ»ΠΎΠ² ΠΈ ΠΈΠ΄Π΅Π½ΡΠΈΡΠΈΠΊΠ°ΡΠΎΡΠΎΠ². Π Π΅Π·ΡΠ»ΡΡΠ°Ρ ΠΏΡΠΈΠ²ΠΎΠ΄ΠΈΡΡΡ ΠΊ Π²Π΅ΡΡ
Π½Π΅ΠΌΡ ΡΠ΅Π³ΠΈΡΡΡΡ,
|
| 125 |
+
ΡΡΠΎΠ±Ρ EM ΡΡΠΈΡΠ°Π»ΡΡ ΠΈΠ΄Π΅Π½ΡΠΈΡΠ½ΠΎ ΡΡΠ°Π»ΠΎΠ½Π½ΠΎΠΉ ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΠΈ Π² ``evaluate_pauq.py``.
|
| 126 |
+
"""
|
| 127 |
try:
|
| 128 |
+
parsed = sqlglot.parse_one(sql, dialect=dialect)
|
| 129 |
+
return parsed.sql(dialect=dialect, normalize=True).upper()
|
| 130 |
+
except (ParseError, ValueError, TypeError):
|
| 131 |
+
return re.sub(r"\s+", " ", sql.upper()).strip().rstrip(";")
|
| 132 |
|
| 133 |
|
| 134 |
def postprocess(raw_output: str) -> str:
|
| 135 |
+
"""ΠΠΎΠ»Π½ΡΠΉ pipeline ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΠΈ Π²ΡΠ²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ.
|
| 136 |
+
|
| 137 |
+
1. Π§ΠΈΡΡΠΊΠ° Π°ΡΡΠ΅ΡΠ°ΠΊΡΠΎΠ² ΡΠ΅ΡΠ΅Π· :func:`strip_model_artifacts`.
|
| 138 |
+
2. ΠΠ°Π»ΠΈΠ΄Π°ΡΠΈΡ ΡΠ΅ΡΠ΅Π· :func:`is_valid_sql`.
|
| 139 |
+
3. ΠΠΎΠ·Π²ΡΠ°Ρ ΠΏΡΡΡΠΎΠΉ ΡΡΡΠΎΠΊΠΈ ΠΏΡΠΈ ΠΏΡΠΎΠ²Π°Π»Π΅ Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΠΈ.
|
| 140 |
+
|
| 141 |
+
Π‘ΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Ρ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ.
|
| 142 |
+
"""
|
| 143 |
+
sql = strip_model_artifacts(raw_output)
|
| 144 |
+
if not is_valid_sql(sql):
|
| 145 |
+
logger.warning("postprocess ΠΎΡΠ±ΡΠΎΡΠΈΠ» Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΡΠΉ SQL: %r", sql[:120])
|
| 146 |
+
return ""
|
| 147 |
+
return sql
|
src/streamlit_app.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
-
|
| 2 |
-
import sys
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
os.chdir(root)
|
| 7 |
-
sys.path.insert(0, root)
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π£ΡΡΠ°ΡΠ΅Π²ΡΠΈΠΉ ΠΌΠΎΠ΄ΡΠ»Ρ.
|
|
|
|
| 2 |
|
| 3 |
+
Π’ΠΎΡΠΊΠ° Π²Ρ
ΠΎΠ΄Π° Streamlit-ΠΏΡΠΈΠ»ΠΎΠΆΠ΅Π½ΠΈΡ β ``streamlit_app.py`` Π² ΠΊΠΎΡΠ½Π΅ ΠΏΡΠΎΠ΅ΠΊΡΠ°.
|
| 4 |
+
ΠΠ°ΠΏΡΡΠΊ:
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
streamlit run streamlit_app.py
|
| 7 |
+
|
| 8 |
+
ΠΡΠΎΡ ΡΠ°ΠΉΠ» ΠΎΡΡΠ°Π²Π»Π΅Π½ ΠΏΡΡΡΡΠΌ ΡΠ°Π΄ΠΈ ΠΎΠ±ΡΠ°ΡΠ½ΠΎΠΉ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΠΈ import-ΠΏΡΡΠ΅ΠΉ.
|
| 9 |
+
Π Π°Π½Π΅Π΅ Π·Π΄Π΅ΡΡ Π±ΡΠ»Π° ΠΎΠ±ΡΡΡΠΊΠ° ΡΠ΅ΡΠ΅Π· ``exec(open(...).read())``, ΠΊΠΎΡΠΎΡΠ°Ρ
|
| 10 |
+
Π·Π°ΠΏΡΡΠΊΠ°Π»Π° UI ΠΊΠ°ΠΊ ΠΏΠΎΠ±ΠΎΡΠ½ΡΠΉ ΡΡΡΠ΅ΠΊΡ ΠΈΠΌΠΏΠΎΡΡΠ° ΠΏΠ°ΠΊΠ΅ΡΠ° ``src`` β ΡΡΠΎ Π»ΠΎΠΌΠ°Π»ΠΎ
|
| 11 |
+
ΡΠ±ΠΎΡ ΡΠ΅ΡΡΠΎΠ² ΠΈ ΡΡΠ°ΡΠΈΡΠ΅ΡΠΊΠΈΠΉ Π°Π½Π°Π»ΠΈΠ·. Π₯Π°ΠΊ ΡΠ΄Π°Π»ΡΠ½.
|
| 12 |
+
"""
|
streamlit_app.py
CHANGED
|
@@ -1,16 +1,44 @@
|
|
| 1 |
-
"""Streamlit-ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ ΡΡΠΈΠ»ΠΈΡΡ Ru2SQL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
|
|
|
| 5 |
import sys
|
| 6 |
-
import
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import streamlit as st
|
| 10 |
|
| 11 |
ROOT = Path(__file__).resolve().parent
|
| 12 |
sys.path.insert(0, str(ROOT))
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
# ΠΠΎΠ½ΡΠΈΠ³ΡΡΠ°ΡΠΈΡ ΡΡΡΠ°Π½ΠΈΡΡ
|
| 16 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -21,11 +49,10 @@ st.set_page_config(
|
|
| 21 |
)
|
| 22 |
|
| 23 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
-
# CSS
|
| 25 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
st.markdown("""
|
| 27 |
<style>
|
| 28 |
-
/* ββ ΠΠ»ΠΎΠ±Π°Π»ΡΠ½ΡΠΉ ΡΡΠΈΡΡ ΠΈ ΡΠΎΠ½ ββ */
|
| 29 |
html, body, [data-testid="stAppViewContainer"] {
|
| 30 |
background-color: #0d1117;
|
| 31 |
font-size: 16px;
|
|
@@ -36,7 +63,6 @@ st.markdown("""
|
|
| 36 |
}
|
| 37 |
[data-testid="stHeader"] { background: transparent; }
|
| 38 |
|
| 39 |
-
/* ββ Π¨Π°ΠΏΠΊΠ° ββ */
|
| 40 |
.app-header {
|
| 41 |
padding: 32px 0 24px 0;
|
| 42 |
border-bottom: 1px solid #30363d;
|
|
@@ -57,8 +83,6 @@ st.markdown("""
|
|
| 57 |
font-weight: 400;
|
| 58 |
letter-spacing: 0.1px;
|
| 59 |
}
|
| 60 |
-
|
| 61 |
-
/* ββ Π‘Π°ΠΉΠ΄Π±Π°Ρ: ΡΠ΅ΠΊΡΠΈΠΈ ββ */
|
| 62 |
.sb-label {
|
| 63 |
font-size: 10px;
|
| 64 |
font-weight: 700;
|
|
@@ -73,32 +97,13 @@ st.markdown("""
|
|
| 73 |
border-top: 1px solid #30363d;
|
| 74 |
margin: 4px 0 0 0;
|
| 75 |
}
|
| 76 |
-
|
| 77 |
-
/* ββ Π‘ΡΠ°ΡΡΡΡ ββ */
|
| 78 |
.status-ok { color: #3fb950; font-size: 13px; font-weight: 600; }
|
| 79 |
.status-err { color: #f85149; font-size: 13px; font-weight: 600; }
|
| 80 |
-
|
| 81 |
-
/* ββ DB-ΠΏΠ΅ΡΠ΅ΠΊΠ»ΡΡΠ°ΡΠ΅Π»Ρ ββ */
|
| 82 |
-
div[data-testid="stRadio"] > div {
|
| 83 |
-
gap: 4px;
|
| 84 |
-
}
|
| 85 |
-
div[data-testid="stRadio"] > div > label {
|
| 86 |
-
font-size: 14px;
|
| 87 |
-
padding: 4px 0;
|
| 88 |
-
}
|
| 89 |
-
|
| 90 |
-
/* ββ ΠΠ½ΠΎΠΏΠΊΠ° ΡΠ»ΠΎΠ²Π°ΡΡ ββ */
|
| 91 |
-
.vocab-status {
|
| 92 |
-
font-size: 12px;
|
| 93 |
-
color: #7d8590;
|
| 94 |
-
margin-top: 6px;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
/* ββ SQL-Π±Π»ΠΎΠΊ ββ */
|
| 98 |
.sql-box {
|
| 99 |
background: #161b22;
|
| 100 |
color: #e6edf3;
|
| 101 |
-
font-family: 'JetBrains Mono', 'Fira Code',
|
| 102 |
font-size: 14px;
|
| 103 |
line-height: 1.7;
|
| 104 |
padding: 20px 24px;
|
|
@@ -108,11 +113,7 @@ st.markdown("""
|
|
| 108 |
white-space: pre-wrap;
|
| 109 |
margin: 14px 0;
|
| 110 |
}
|
| 111 |
-
|
| 112 |
-
/* ββ ΠΠΊΠ»Π°Π΄ΠΊΠΈ ββ */
|
| 113 |
[data-testid="stTabs"] button { font-size: 15px; font-weight: 500; }
|
| 114 |
-
|
| 115 |
-
/* ββ ΠΡΠΈΠΌΠ΅ΡΡ Π·Π°ΠΏΡΠΎΡΠΎΠ² ββ */
|
| 116 |
.examples-label {
|
| 117 |
font-size: 11px;
|
| 118 |
font-weight: 700;
|
|
@@ -121,47 +122,21 @@ st.markdown("""
|
|
| 121 |
color: #7d8590;
|
| 122 |
margin: 24px 0 10px 0;
|
| 123 |
}
|
| 124 |
-
|
| 125 |
-
/* ββ ΠΠΎΠ»Π΅ Π²Π²ΠΎΠ΄Π° Π²ΠΎΠΏΡΠΎΡΠ° ββ */
|
| 126 |
[data-testid="stTextArea"] textarea {
|
| 127 |
font-size: 16px !important;
|
| 128 |
line-height: 1.6 !important;
|
| 129 |
}
|
| 130 |
-
|
| 131 |
-
/* ββ ΠΠ½ΠΎΠΏΠΊΠ° Β«ΠΡΠΏΠΎΠ»Π½ΠΈΡΡΒ» ββ */
|
| 132 |
[data-testid="stButton"] > button[kind="primary"] {
|
| 133 |
font-size: 15px;
|
| 134 |
padding: 10px 28px;
|
| 135 |
border-radius: 8px;
|
| 136 |
font-weight: 600;
|
| 137 |
}
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
[data-testid="
|
| 141 |
-
|
| 142 |
-
color: #7d8590 !important;
|
| 143 |
-
}
|
| 144 |
-
[data-testid="stMetricValue"] {
|
| 145 |
-
font-size: 22px !important;
|
| 146 |
-
color: #e6edf3 !important;
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
/* ββ ΠΡΠ΅Π΄ΡΠΏΡΠ΅ΠΆΠ΄Π΅Π½ΠΈΠ΅ ΠΎ Π½Π΅Π³ΠΎΡΠΎΠ²Π½ΠΎΡΡΠΈ ββ */
|
| 150 |
-
[data-testid="stAlertContainer"] {
|
| 151 |
-
border-radius: 8px;
|
| 152 |
-
font-size: 14px;
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
/* ββ Expander ΡΡ
Π΅ΠΌΡ ββ */
|
| 156 |
-
[data-testid="stExpander"] summary {
|
| 157 |
-
font-size: 15px;
|
| 158 |
-
font-weight: 500;
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
/* ββ Π‘ΠΊΡΡΡΡ ΠΊΠ½ΠΎΠΏΠΊΡ Stop ββ */
|
| 162 |
button[kind="stop"] { display: none !important; }
|
| 163 |
-
|
| 164 |
-
/* ββ ΠΠΎΠ΄Π°Π»ΡΠ½ΡΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ ΡΠ»ΠΎΠ²Π°ΡΡ ββ */
|
| 165 |
[data-testid="stDialog"] textarea {
|
| 166 |
font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
|
| 167 |
font-size: 13px !important;
|
|
@@ -178,27 +153,21 @@ def _default_vocab_yaml() -> str:
|
|
| 178 |
example = ROOT / "configs" / "example_vocabulary.yaml"
|
| 179 |
if example.exists():
|
| 180 |
return example.read_text(encoding="utf-8")
|
| 181 |
-
return
|
| 182 |
-
"company: ΠΠΎΡ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΡ\n\n"
|
| 183 |
-
"terms:\n"
|
| 184 |
-
" Π²ΡΡΡΡΠΊΠ°: SUM(orders.amount) WHERE status = 'paid'\n\n"
|
| 185 |
-
"filters:\n"
|
| 186 |
-
" ΡΠΎΠ»ΡΠΊΠΎ_ΠΎΠΏΠ»Π°ΡΠ΅Π½Π½ΡΠ΅: orders.status = 'paid'\n\n"
|
| 187 |
-
"notes: []\n"
|
| 188 |
-
)
|
| 189 |
|
| 190 |
|
| 191 |
def _init_state():
|
| 192 |
defaults = {
|
| 193 |
-
"history":
|
| 194 |
-
"
|
| 195 |
-
"
|
| 196 |
-
"
|
| 197 |
-
"
|
| 198 |
-
"
|
| 199 |
-
"
|
| 200 |
-
"vocab_yaml":
|
| 201 |
-
"db_mode":
|
|
|
|
| 202 |
}
|
| 203 |
for k, v in defaults.items():
|
| 204 |
if k not in st.session_state:
|
|
@@ -209,58 +178,90 @@ _init_state()
|
|
| 209 |
|
| 210 |
|
| 211 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
-
#
|
| 213 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
-
def
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
import tempfile
|
| 230 |
-
from src.business.vocabulary import BusinessVocabulary
|
| 231 |
tmp = Path(tempfile.mktemp(suffix=".yaml"))
|
| 232 |
tmp.write_text(yaml_text, encoding="utf-8")
|
| 233 |
-
vocab = BusinessVocabulary.from_yaml(tmp)
|
| 234 |
-
tmp.unlink(missing_ok=True)
|
| 235 |
-
return vocab
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
def _auto_connect_demo():
|
| 239 |
-
"""ΠΠΎΠ΄ΠΊΠ»ΡΡΠΈΡΡ Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ ΠΈ ΡΠ»ΠΎΠ²Π°ΡΡ ΠΊ Π½Π΅ΠΉ."""
|
| 240 |
-
demo_path = ROOT / "data" / "demo" / "sales.sqlite"
|
| 241 |
-
cs = str(demo_path)
|
| 242 |
try:
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
st.session_state.db_connection_string = cs
|
| 247 |
-
if st.session_state.vocabulary is None:
|
| 248 |
-
vocab_path = ROOT / "configs" / "example_vocabulary.yaml"
|
| 249 |
-
if vocab_path.exists():
|
| 250 |
-
st.session_state.vocabulary = _load_vocab_from_yaml(
|
| 251 |
-
vocab_path.read_text(encoding="utf-8")
|
| 252 |
-
)
|
| 253 |
-
except Exception:
|
| 254 |
-
pass
|
| 255 |
|
| 256 |
|
| 257 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 258 |
-
#
|
| 259 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 260 |
@st.dialog("ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ", width="large")
|
| 261 |
def vocab_dialog():
|
| 262 |
st.caption(
|
| 263 |
-
"ΠΠΏΠΈΡΠΈΡΠ΅ ΡΠ΅ΡΠΌΠΈΠ½Ρ
|
| 264 |
"ΠΠΎΠ΄Π΅Π»Ρ Π±ΡΠ΄Π΅Ρ ΡΡΠΈΡΡΠ²Π°ΡΡ ΠΈΡ
ΠΏΡΠΈ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ SQL."
|
| 265 |
)
|
| 266 |
yaml_text = st.text_area(
|
|
@@ -269,77 +270,72 @@ def vocab_dialog():
|
|
| 269 |
height=480,
|
| 270 |
label_visibility="collapsed",
|
| 271 |
)
|
| 272 |
-
|
| 273 |
-
with
|
| 274 |
-
if st.button("ΠΡΠΈΠΌΠ΅Π½ΠΈΡΡ", type="primary",
|
| 275 |
try:
|
| 276 |
-
|
| 277 |
-
st.session_state.vocabulary = vocab
|
| 278 |
st.session_state.vocab_yaml = yaml_text
|
| 279 |
st.rerun()
|
| 280 |
except Exception as e:
|
| 281 |
st.error(f"ΠΡΠΈΠ±ΠΊΠ° ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡΠ° YAML: {e}")
|
| 282 |
-
with
|
| 283 |
-
if st.button("ΠΡΠΌΠ΅Π½Π°",
|
| 284 |
st.rerun()
|
| 285 |
|
| 286 |
|
| 287 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 288 |
-
#
|
| 289 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
with st.sidebar:
|
| 291 |
|
| 292 |
-
# ββ
|
| 293 |
-
st.markdown('<p class="sb-label">
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
if st.session_state.model_loaded:
|
| 303 |
-
st.markdown(
|
| 304 |
-
'<span class="status-ok">β
Qwen2.5-Coder-3B + QLoRA</span>',
|
| 305 |
-
unsafe_allow_html=True,
|
| 306 |
-
)
|
| 307 |
else:
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
|
| 314 |
|
| 315 |
# ββ ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
ββ
|
| 316 |
st.markdown('<p class="sb-label">ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
</p>', unsafe_allow_html=True)
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
db_mode = st.radio(
|
| 321 |
-
"ΠΡΡΠΎΡΠ½ΠΈΠΊ Π΄Π°Π½Π½ΡΡ
",
|
| 322 |
-
|
| 323 |
-
index=_modes.index(_prev) if _prev in _modes else None,
|
| 324 |
label_visibility="collapsed",
|
| 325 |
)
|
| 326 |
-
if db_mode !=
|
| 327 |
-
|
| 328 |
-
st.session_state.
|
| 329 |
-
st.session_state.db_executor = None
|
| 330 |
st.session_state.db_mode = db_mode
|
| 331 |
|
| 332 |
cs = ""
|
| 333 |
-
|
| 334 |
if db_mode == "ΠΠ΅ΠΌΠΎ-Π±Π°Π·Π°":
|
| 335 |
st.caption("ΠΡΡΡΠΎΠ΅Π½Π½Π°Ρ Π±Π°Π·Π°: ΠΈΠ½ΡΠ΅ΡΠ½Π΅Ρ-ΠΌΠ°Π³Π°Π·ΠΈΠ½ ΡΠ»Π΅ΠΊΡΡΠΎΠ½ΠΈΠΊΠΈ, 120 Π·Π°ΠΊΠ°Π·ΠΎΠ².")
|
| 336 |
-
|
| 337 |
-
cs = str(demo_path)
|
| 338 |
-
|
| 339 |
elif db_mode == "ΠΠ°Π³ΡΡΠ·ΠΈΡΡ ΡΠ°ΠΉΠ»":
|
| 340 |
uploaded = st.file_uploader(
|
| 341 |
-
"SQLite-ΡΠ°ΠΉΠ» Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
",
|
| 342 |
-
type=["sqlite", "db"],
|
| 343 |
label_visibility="collapsed",
|
| 344 |
)
|
| 345 |
if uploaded:
|
|
@@ -349,81 +345,77 @@ with st.sidebar:
|
|
| 349 |
cs = str(tmp_db)
|
| 350 |
else:
|
| 351 |
st.caption("ΠΠ΅ΡΠ΅ΡΠ°ΡΠΈΡΠ΅ .sqlite ΠΈΠ»ΠΈ .db ΡΠ°ΠΉΠ» ΡΡΠ΄Π°")
|
| 352 |
-
|
| 353 |
-
else: # Π‘ΡΡΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ
|
| 354 |
cs = st.text_input(
|
| 355 |
"Π‘ΡΡΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ",
|
| 356 |
placeholder="postgresql://user:pass@host:5432/db",
|
| 357 |
-
value=st.session_state.
|
| 358 |
label_visibility="collapsed",
|
| 359 |
)
|
| 360 |
-
st.caption("PostgreSQL Β· MySQL
|
| 361 |
-
|
| 362 |
-
if cs and st.button("ΠΠΎΠ΄ΠΊΠ»ΡΡΠΈΡΡΡΡ",
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
st.
|
| 367 |
-
st.session_state.
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
if "sales" in cs and st.session_state.vocabulary is None:
|
| 370 |
-
|
| 371 |
-
if
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
| 375 |
st.success(f"ΠΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΎ. Π’Π°Π±Π»ΠΈΡ: {len(tables)}")
|
| 376 |
-
except Exception as e:
|
| 377 |
-
st.error(f"ΠΡΠΈΠ±ΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ: {e}")
|
| 378 |
|
| 379 |
-
if st.session_state.
|
| 380 |
-
|
| 381 |
st.markdown(
|
| 382 |
'<span class="status-ok">β
ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½Π°</span>',
|
| 383 |
unsafe_allow_html=True,
|
| 384 |
)
|
| 385 |
-
with st.expander(f"Π’Π°Π±Π»ΠΈΡΡ ({
|
| 386 |
-
for t in
|
| 387 |
-
st.code(t, language=None)
|
| 388 |
|
| 389 |
st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
|
| 390 |
|
| 391 |
# ββ ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ ββ
|
| 392 |
st.markdown('<p class="sb-label">ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ</p>', unsafe_allow_html=True)
|
| 393 |
-
|
| 394 |
if st.session_state.vocabulary:
|
| 395 |
v = st.session_state.vocabulary
|
| 396 |
label = v.company if v.company else "ΠΠ°Π³ΡΡΠΆΠ΅Π½"
|
| 397 |
-
st.markdown(
|
| 398 |
-
f'<span class="status-ok">β
{label}</span>',
|
| 399 |
-
unsafe_allow_html=True,
|
| 400 |
-
)
|
| 401 |
if v.terms:
|
| 402 |
-
st.
|
| 403 |
-
f'<span class="vocab-status">Π’Π΅ΡΠΌΠΈΠ½ΠΎΠ²: {len(v.terms)}</span>',
|
| 404 |
-
unsafe_allow_html=True,
|
| 405 |
-
)
|
| 406 |
else:
|
| 407 |
-
st.
|
| 408 |
-
|
| 409 |
-
unsafe_allow_html=True,
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
if st.button("Π Π΅Π΄Π°ΠΊΡΠΈΡΠΎΠ²Π°ΡΡ ΡΠ»ΠΎΠ²Π°ΡΡ", use_container_width=True):
|
| 413 |
vocab_dialog()
|
| 414 |
|
| 415 |
|
| 416 |
-
|
| 417 |
-
|
| 418 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 419 |
-
#
|
| 420 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 421 |
st.markdown("""
|
| 422 |
<div class="app-header">
|
| 423 |
<p class="app-title">Ru2SQL β Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΡ Π·Π°ΠΏΡΠΎΡΠΎΠ²<br>
|
| 424 |
ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² Π·Π°ΠΏΡΠΎΡΡ Π½Π° ΡΠ·ΡΠΊΠ΅ SQL</p>
|
| 425 |
<p class="app-subtitle">
|
| 426 |
-
Qwen2.5-Coder-3B-Instruct Β· QLoRA
|
| 427 |
Β· SQLite / PostgreSQL / MySQL
|
| 428 |
</p>
|
| 429 |
</div>
|
|
@@ -431,15 +423,21 @@ st.markdown("""
|
|
| 431 |
|
| 432 |
tab_query, tab_schema, tab_history = st.tabs(["ΠΠ°ΠΏΡΠΎΡ", "Π‘Ρ
Π΅ΠΌΠ° Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
", "ΠΡΡΠΎΡΠΈΡ"])
|
| 433 |
|
| 434 |
-
|
|
|
|
| 435 |
with tab_query:
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
if not ready:
|
| 439 |
missing = []
|
| 440 |
-
if not
|
| 441 |
-
missing.append("ΠΌΠΎΠ΄Π΅Π»Ρ
|
| 442 |
-
if
|
| 443 |
missing.append("Π±Π°Π·Π° Π΄Π°Π½Π½ΡΡ
Π½Π΅ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½Π°")
|
| 444 |
st.warning("Π‘ΠΈΡΡΠ΅ΠΌΠ° Π½Π΅ Π³ΠΎΡΠΎΠ²Π°: " + ", ".join(missing) + ". ΠΡΠΏΠΎΠ»ΡΠ·ΡΠΉΡΠ΅ ΠΏΠ°Π½Π΅Π»Ρ ΡΠ»Π΅Π²Π°.")
|
| 445 |
|
|
@@ -450,19 +448,19 @@ with tab_query:
|
|
| 450 |
disabled=not ready,
|
| 451 |
)
|
| 452 |
|
| 453 |
-
col_btn,
|
| 454 |
with col_btn:
|
| 455 |
run_btn = st.button(
|
| 456 |
"ΠΡΠΏΠΎΠ»Π½ΠΈΡΡ",
|
| 457 |
type="primary",
|
| 458 |
disabled=not ready or not question.strip(),
|
| 459 |
-
|
| 460 |
)
|
| 461 |
|
| 462 |
# ΠΡΠΈΠΌΠ΅ΡΡ Π΄Π»Ρ Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ
|
| 463 |
if (
|
| 464 |
-
st.session_state.
|
| 465 |
-
and "sales" in st.session_state.
|
| 466 |
):
|
| 467 |
st.markdown('<p class="examples-label">ΠΡΠΈΠΌΠ΅ΡΡ Π·Π°ΠΏΡΠΎΡΠΎΠ²</p>', unsafe_allow_html=True)
|
| 468 |
ex_cols = st.columns(3)
|
|
@@ -473,78 +471,76 @@ with tab_query:
|
|
| 473 |
]
|
| 474 |
for i, ex in enumerate(examples):
|
| 475 |
with ex_cols[i]:
|
| 476 |
-
if st.button(ex, key=f"ex_{i}",
|
| 477 |
question = ex
|
| 478 |
-
run_btn
|
| 479 |
|
| 480 |
if run_btn and question.strip():
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
executor = st.session_state.db_executor
|
| 484 |
-
vocab = st.session_state.vocabulary
|
| 485 |
-
|
| 486 |
-
enriched = vocab.enrich_prompt(question) if vocab else question
|
| 487 |
-
schema = connector.render_schema(include_samples=True)
|
| 488 |
|
| 489 |
-
with st.spinner("
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
| 493 |
|
| 494 |
st.markdown("**Π‘Π³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΡΠΉ SQL**")
|
| 495 |
-
st.markdown(f'<div class="sql-box">{
|
| 496 |
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
qr = executor.run(result.sql)
|
| 501 |
|
| 502 |
c1, c2, c3 = st.columns(3)
|
| 503 |
c1.metric("ΠΡΠ΅ΠΌΡ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ", f"{gen_time:.1f} Ρ")
|
| 504 |
-
if
|
| 505 |
-
c2.metric("Π‘ΡΡΠΎΠΊ ΠΏΠΎΠ»ΡΡΠ΅Π½ΠΎ",
|
| 506 |
-
c3.metric("Π‘ΡΠ°ΡΡΡ", "Π£ΡΠΏΠ΅ΡΠ½ΠΎ"
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
elif
|
| 517 |
-
st.
|
|
|
|
|
|
|
| 518 |
|
| 519 |
st.session_state.history.append({
|
| 520 |
"question": question,
|
| 521 |
-
"sql":
|
| 522 |
-
"success":
|
| 523 |
-
"rows":
|
| 524 |
"time": gen_time,
|
| 525 |
})
|
| 526 |
|
| 527 |
-
|
|
|
|
| 528 |
with tab_schema:
|
| 529 |
-
if st.session_state.
|
| 530 |
st.info("ΠΠΎΠ΄ΠΊΠ»ΡΡΠΈΡΠ΅ΡΡ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
ΡΠ΅ΡΠ΅Π· ΠΏΠ°Π½Π΅Π»Ρ ΡΠ»Π΅Π²Π°.")
|
| 531 |
else:
|
| 532 |
-
connector = st.session_state.db_connector
|
| 533 |
show_samples = st.toggle("ΠΠΎΠΊΠ°Π·ΡΠ²Π°ΡΡ ΠΏΡΠΈΠΌΠ΅ΡΡ Π΄Π°Π½Π½ΡΡ
", value=True)
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
if show_samples and table.sample_rows:
|
| 539 |
import pandas as pd
|
| 540 |
-
cols = [c
|
| 541 |
st.caption("ΠΡΠΈΠΌΠ΅ΡΡ Π΄Π°Π½Π½ΡΡ
:")
|
| 542 |
st.dataframe(
|
| 543 |
-
pd.DataFrame(
|
| 544 |
-
|
| 545 |
)
|
| 546 |
|
| 547 |
-
|
|
|
|
| 548 |
with tab_history:
|
| 549 |
history = st.session_state.history
|
| 550 |
if not history:
|
|
@@ -554,7 +550,7 @@ with tab_history:
|
|
| 554 |
with col_h:
|
| 555 |
st.markdown(f"**ΠΠ°ΠΏΡΠΎΡΠΎΠ² Π² ΡΠ΅ΡΡΠΈΠΈ: {len(history)}**")
|
| 556 |
with col_clr:
|
| 557 |
-
if st.button("ΠΡΠΈΡΡΠΈΡΡ",
|
| 558 |
st.session_state.history = []
|
| 559 |
st.rerun()
|
| 560 |
|
|
|
|
| 1 |
+
"""Streamlit-ΠΈΠ½ΡΠ΅ΡΡΠ΅ΠΉΡ ΡΡΠΈΠ»ΠΈΡΡ Ru2SQL.
|
| 2 |
+
|
| 3 |
+
ΠΡΡ
ΠΈΡΠ΅ΠΊΡΡΡΠ½ΠΎ β ΠΊΠ»ΠΈΠ΅Π½Ρ REST API Π½Π° FastAPI. Π‘ΠΎΠΎΡΠ²Π΅ΡΡΡΠ²ΡΠ΅Ρ ΡΠ°Π·Π΄Π΅Π»Ρ 3.5
|
| 4 |
+
ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ: Π²ΡΠ΅ ΠΎΠ±ΡΠ°ΡΠ΅Π½ΠΈΡ ΠΊ ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
ΠΈΠ΄ΡΡ ΡΠ΅ΡΠ΅Π·
|
| 5 |
+
HTTP ΠΊ ``src.api.main:app``. ΠΠ°ΠΏΡΡΠΊ Π΄Π²ΡΡ
ΠΏΡΠΎΡΠ΅ΡΡΠΎΠ²:
|
| 6 |
+
|
| 7 |
+
uvicorn src.api.main:app --reload # Π½Π° 127.0.0.1:8000
|
| 8 |
+
streamlit run streamlit_app.py # Π½Π° 127.0.0.1:8501
|
| 9 |
+
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
import sys
|
| 16 |
+
import warnings
|
| 17 |
from pathlib import Path
|
| 18 |
|
| 19 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
# ΠΠ»ΡΡΠΈΠΌ ΡΡΠΌΠ½ΡΠ΅ warning'ΠΈ
|
| 21 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
# Streamlit-watcher Ρ
ΠΎΠ΄ΠΈΡ ΠΏΠΎ Π²ΡΠ΅ΠΌΡ ΠΏΠ°ΠΊΠ΅ΡΡ transformers (image-processors)
|
| 23 |
+
# ΠΈ ΡΠΏΠ°ΠΌΠΈΡ ModuleNotFoundError ΠΏΡΠΎ torchvision. ΠΠ° ΡΠ°Π±ΠΎΡΡ ΡΡΠΎ Π½Π΅ Π²Π»ΠΈΡΠ΅Ρ β
|
| 24 |
+
# Qwen2.5-Coder text-only, torchvision Π½Π΅ Π½ΡΠΆΠ΅Π½.
|
| 25 |
+
warnings.filterwarnings("ignore", message=".*torchvision.*")
|
| 26 |
+
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 27 |
+
logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
|
| 28 |
+
|
| 29 |
+
import httpx
|
| 30 |
import streamlit as st
|
| 31 |
|
| 32 |
ROOT = Path(__file__).resolve().parent
|
| 33 |
sys.path.insert(0, str(ROOT))
|
| 34 |
|
| 35 |
+
# ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ ΠΏΠ°ΡΡΠΈΠΌ Π»ΠΎΠΊΠ°Π»ΡΠ½ΠΎ β ΠΎΠ½ Π½Π΅ ΡΡΠ΅Π±ΡΠ΅Ρ ΠΎΠ±ΡΠ°ΡΠ΅Π½ΠΈΡ ΠΊ ΡΠ΅ΡΠ²Π΅ΡΡ
|
| 36 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 37 |
+
|
| 38 |
+
API_URL = os.environ.get("RU2SQL_API_URL", "http://127.0.0.1:8000")
|
| 39 |
+
QUERY_TIMEOUT = 1800.0 # 30 ΠΌΠΈΠ½ΡΡ β ΡΠ°ΠΊΡΠΈΡΠ΅ΡΠΊΠΈ Π±Π΅Π·Π»ΠΈΠΌΠΈΡ
|
| 40 |
+
SHORT_TIMEOUT = 10.0 # Π΄Π»Ρ /health, /schema
|
| 41 |
+
|
| 42 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
# ΠΠΎΠ½ΡΠΈΠ³ΡΡΠ°ΡΠΈΡ ΡΡΡΠ°Π½ΠΈΡΡ
|
| 44 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 49 |
)
|
| 50 |
|
| 51 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
# CSS β ΠΎΡΠΎΡΠΌΠ»Π΅Π½ΠΈΠ΅ Π² ΡΡΠΈΠ»Π΅ ΡΡΠΌΠ½ΠΎΠΉ ΡΠ΅ΠΌΡ GitHub
|
| 53 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
st.markdown("""
|
| 55 |
<style>
|
|
|
|
| 56 |
html, body, [data-testid="stAppViewContainer"] {
|
| 57 |
background-color: #0d1117;
|
| 58 |
font-size: 16px;
|
|
|
|
| 63 |
}
|
| 64 |
[data-testid="stHeader"] { background: transparent; }
|
| 65 |
|
|
|
|
| 66 |
.app-header {
|
| 67 |
padding: 32px 0 24px 0;
|
| 68 |
border-bottom: 1px solid #30363d;
|
|
|
|
| 83 |
font-weight: 400;
|
| 84 |
letter-spacing: 0.1px;
|
| 85 |
}
|
|
|
|
|
|
|
| 86 |
.sb-label {
|
| 87 |
font-size: 10px;
|
| 88 |
font-weight: 700;
|
|
|
|
| 97 |
border-top: 1px solid #30363d;
|
| 98 |
margin: 4px 0 0 0;
|
| 99 |
}
|
|
|
|
|
|
|
| 100 |
.status-ok { color: #3fb950; font-size: 13px; font-weight: 600; }
|
| 101 |
.status-err { color: #f85149; font-size: 13px; font-weight: 600; }
|
| 102 |
+
.status-warn { color: #d29922; font-size: 13px; font-weight: 600; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
.sql-box {
|
| 104 |
background: #161b22;
|
| 105 |
color: #e6edf3;
|
| 106 |
+
font-family: 'JetBrains Mono', 'Fira Code', monospace;
|
| 107 |
font-size: 14px;
|
| 108 |
line-height: 1.7;
|
| 109 |
padding: 20px 24px;
|
|
|
|
| 113 |
white-space: pre-wrap;
|
| 114 |
margin: 14px 0;
|
| 115 |
}
|
|
|
|
|
|
|
| 116 |
[data-testid="stTabs"] button { font-size: 15px; font-weight: 500; }
|
|
|
|
|
|
|
| 117 |
.examples-label {
|
| 118 |
font-size: 11px;
|
| 119 |
font-weight: 700;
|
|
|
|
| 122 |
color: #7d8590;
|
| 123 |
margin: 24px 0 10px 0;
|
| 124 |
}
|
|
|
|
|
|
|
| 125 |
[data-testid="stTextArea"] textarea {
|
| 126 |
font-size: 16px !important;
|
| 127 |
line-height: 1.6 !important;
|
| 128 |
}
|
|
|
|
|
|
|
| 129 |
[data-testid="stButton"] > button[kind="primary"] {
|
| 130 |
font-size: 15px;
|
| 131 |
padding: 10px 28px;
|
| 132 |
border-radius: 8px;
|
| 133 |
font-weight: 600;
|
| 134 |
}
|
| 135 |
+
[data-testid="stMetric"] label { font-size: 12px !important; color: #7d8590 !important; }
|
| 136 |
+
[data-testid="stMetricValue"] { font-size: 22px !important; color: #e6edf3 !important; }
|
| 137 |
+
[data-testid="stAlertContainer"] { border-radius: 8px; font-size: 14px; }
|
| 138 |
+
[data-testid="stExpander"] summary { font-size: 15px; font-weight: 500; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
button[kind="stop"] { display: none !important; }
|
|
|
|
|
|
|
| 140 |
[data-testid="stDialog"] textarea {
|
| 141 |
font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
|
| 142 |
font-size: 13px !important;
|
|
|
|
| 153 |
example = ROOT / "configs" / "example_vocabulary.yaml"
|
| 154 |
if example.exists():
|
| 155 |
return example.read_text(encoding="utf-8")
|
| 156 |
+
return "company: ΠΠΎΡ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΡ\n\nterms: {}\nfilters: {}\nnotes: []\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def _init_state():
|
| 160 |
defaults = {
|
| 161 |
+
"history": [],
|
| 162 |
+
"api_health": None, # dict | None
|
| 163 |
+
"api_error": None, # str | None
|
| 164 |
+
"connection_string": "",
|
| 165 |
+
"schema_tables": None, # list[TablePayload-like dict] | None
|
| 166 |
+
"schema_error": None,
|
| 167 |
+
"vocabulary": None, # BusinessVocabulary | None
|
| 168 |
+
"vocab_yaml": _default_vocab_yaml(),
|
| 169 |
+
"db_mode": None,
|
| 170 |
+
"warmup_done": False,
|
| 171 |
}
|
| 172 |
for k, v in defaults.items():
|
| 173 |
if k not in st.session_state:
|
|
|
|
| 178 |
|
| 179 |
|
| 180 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
+
# ΠΠ±ΡΡΡΠΊΠΈ Π½Π°Π΄ API
|
| 182 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 183 |
+
def _api_get_health() -> dict | None:
|
| 184 |
+
"""GET /health. None Π΅ΡΠ»ΠΈ API Π½Π΅Π΄ΠΎΡΡΡΠΏΠ΅Π½."""
|
| 185 |
+
try:
|
| 186 |
+
r = httpx.get(f"{API_URL}/health", timeout=SHORT_TIMEOUT)
|
| 187 |
+
r.raise_for_status()
|
| 188 |
+
return r.json()
|
| 189 |
+
except Exception as e:
|
| 190 |
+
st.session_state.api_error = str(e)
|
| 191 |
+
return None
|
| 192 |
|
| 193 |
|
| 194 |
+
def _api_get_schema(cs: str) -> tuple[list[dict] | None, str | None]:
|
| 195 |
+
"""POST /schema. ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ (tables, error)."""
|
| 196 |
+
try:
|
| 197 |
+
r = httpx.post(
|
| 198 |
+
f"{API_URL}/schema",
|
| 199 |
+
json={"connection_string": cs, "include_samples": True},
|
| 200 |
+
timeout=SHORT_TIMEOUT,
|
| 201 |
+
)
|
| 202 |
+
if r.status_code != 200:
|
| 203 |
+
try:
|
| 204 |
+
return None, r.json().get("detail", r.text)
|
| 205 |
+
except Exception:
|
| 206 |
+
return None, r.text
|
| 207 |
+
return r.json().get("tables", []), None
|
| 208 |
+
except Exception as e:
|
| 209 |
+
return None, str(e)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _api_query(question: str, cs: str, vocab: BusinessVocabulary | None) -> dict:
|
| 213 |
+
"""POST /query β Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ SQL + ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΠΎΠ΅ ΠΈΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅."""
|
| 214 |
+
payload = {
|
| 215 |
+
"question": question,
|
| 216 |
+
"connection_string": cs,
|
| 217 |
+
"execute": True,
|
| 218 |
+
}
|
| 219 |
+
if vocab is not None and bool(vocab):
|
| 220 |
+
payload["vocabulary"] = {
|
| 221 |
+
"company": vocab.company,
|
| 222 |
+
"terms": vocab.terms,
|
| 223 |
+
"filters": vocab.filters,
|
| 224 |
+
"notes": vocab.notes,
|
| 225 |
+
}
|
| 226 |
+
r = httpx.post(f"{API_URL}/query", json=payload, timeout=QUERY_TIMEOUT)
|
| 227 |
+
if r.status_code != 200:
|
| 228 |
+
try:
|
| 229 |
+
detail = r.json().get("detail", r.text)
|
| 230 |
+
except Exception:
|
| 231 |
+
detail = r.text
|
| 232 |
+
raise RuntimeError(f"API Π²Π΅ΡΠ½ΡΠ» {r.status_code}: {detail}")
|
| 233 |
+
return r.json()
|
| 234 |
|
| 235 |
|
| 236 |
+
|
| 237 |
+
def _api_warmup() -> tuple[bool, str | None]:
|
| 238 |
+
"""POST /warmup β ΠΊΠΎΡΠΎΡΠΊΠΈΠΉ ΠΏΡΠΎΠ³ΠΎΠ½ Π΄Π»Ρ ΠΏΡΠΎΠ³ΡΠ΅Π²Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ Π½Π° CPU."""
|
| 239 |
+
try:
|
| 240 |
+
r = httpx.post(f"{API_URL}/warmup", timeout=QUERY_TIMEOUT)
|
| 241 |
+
if r.status_code == 200:
|
| 242 |
+
return True, None
|
| 243 |
+
return False, r.text
|
| 244 |
+
except Exception as e:
|
| 245 |
+
return False, str(e)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _load_vocab_from_yaml(yaml_text: str) -> BusinessVocabulary:
|
| 249 |
import tempfile
|
|
|
|
| 250 |
tmp = Path(tempfile.mktemp(suffix=".yaml"))
|
| 251 |
tmp.write_text(yaml_text, encoding="utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
try:
|
| 253 |
+
return BusinessVocabulary.from_yaml(tmp)
|
| 254 |
+
finally:
|
| 255 |
+
tmp.unlink(missing_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
|
| 258 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 259 |
+
# ΠΠΈΠ°Π»ΠΎΠ³ ΡΠ΅Π΄Π°ΠΊΡΠΈΡΠΎΠ²Π°Π½ΠΈΡ Π±ΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ
|
| 260 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
@st.dialog("ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ", width="large")
|
| 262 |
def vocab_dialog():
|
| 263 |
st.caption(
|
| 264 |
+
"ΠΠΏΠΈΡΠΈΡΠ΅ ΡΠ΅ΡΠΌΠΈΠ½Ρ ΠΈ ΠΌΠ΅ΡΡΠΈΠΊΠΈ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ Π² ΡΠΎΡΠΌΠ°ΡΠ΅ YAML. "
|
| 265 |
"ΠΠΎΠ΄Π΅Π»Ρ Π±ΡΠ΄Π΅Ρ ΡΡΠΈΡΡΠ²Π°ΡΡ ΠΈΡ
ΠΏΡΠΈ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ SQL."
|
| 266 |
)
|
| 267 |
yaml_text = st.text_area(
|
|
|
|
| 270 |
height=480,
|
| 271 |
label_visibility="collapsed",
|
| 272 |
)
|
| 273 |
+
c1, c2 = st.columns(2)
|
| 274 |
+
with c1:
|
| 275 |
+
if st.button("ΠΡΠΈΠΌΠ΅Π½ΠΈΡΡ", type="primary", width='stretch'):
|
| 276 |
try:
|
| 277 |
+
st.session_state.vocabulary = _load_vocab_from_yaml(yaml_text)
|
|
|
|
| 278 |
st.session_state.vocab_yaml = yaml_text
|
| 279 |
st.rerun()
|
| 280 |
except Exception as e:
|
| 281 |
st.error(f"ΠΡΠΈΠ±ΠΊΠ° ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡΠ° YAML: {e}")
|
| 282 |
+
with c2:
|
| 283 |
+
if st.button("ΠΡΠΌΠ΅Π½Π°", width='stretch'):
|
| 284 |
st.rerun()
|
| 285 |
|
| 286 |
|
| 287 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 288 |
+
# Sidebar
|
| 289 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 290 |
with st.sidebar:
|
| 291 |
|
| 292 |
+
# ββ API ββ
|
| 293 |
+
st.markdown('<p class="sb-label">API</p>', unsafe_allow_html=True)
|
| 294 |
+
health = _api_get_health()
|
| 295 |
+
st.session_state.api_health = health
|
| 296 |
+
if health is None:
|
| 297 |
+
st.markdown('<span class="status-err">API Π½Π΅Π΄ΠΎΡΡΡΠΏΠ΅Π½</span>', unsafe_allow_html=True)
|
| 298 |
+
st.caption(f"ΠΠ΄ΡΠ΅Ρ: {API_URL}")
|
| 299 |
+
st.caption("ΠΠ°ΠΏΡΡΡΠΈ Π² ΠΎΡΠ΄Π΅Π»ΡΠ½ΠΎΠΉ ΠΊΠΎΠ½ΡΠΎΠ»ΠΈ: `uvicorn src.api.main:app --reload`")
|
| 300 |
+
if st.session_state.api_error:
|
| 301 |
+
st.caption(f"ΠΡΠΈΡΠΈΠ½Π°: {st.session_state.api_error[:160]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
else:
|
| 303 |
+
if health.get("model_loaded"):
|
| 304 |
+
st.markdown(
|
| 305 |
+
f'<span class="status-ok">β
{health.get("base_model", "ΠΌΠΎΠ΄Π΅Π»Ρ")}</span>',
|
| 306 |
+
unsafe_allow_html=True,
|
| 307 |
+
)
|
| 308 |
+
else:
|
| 309 |
+
st.markdown(
|
| 310 |
+
'<span class="status-warn">β³ ΠΠΎΠ΄Π΅Π»Ρ Π΅ΡΡ Π·Π°Π³ΡΡΠΆΠ°Π΅ΡΡΡ</span>',
|
| 311 |
+
unsafe_allow_html=True,
|
| 312 |
+
)
|
| 313 |
+
st.caption("ΠΠΎΠ΄ΠΎΠΆΠ΄ΠΈΡΠ΅ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ ΠΌΠΈΠ½ΡΡ β ΠΌΠΎΠ΄Π΅Π»Ρ Π΅ΡΡ ΠΈΠ½ΠΈΡΠΈΠ°Π»ΠΈΠ·ΠΈΡΡΠ΅ΡΡΡ.")
|
| 314 |
|
| 315 |
st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
|
| 316 |
|
| 317 |
# ββ ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
ββ
|
| 318 |
st.markdown('<p class="sb-label">ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
</p>', unsafe_allow_html=True)
|
| 319 |
|
| 320 |
+
modes = ["ΠΠ΅ΠΌΠΎ-Π±Π°Π·Π°", "ΠΠ°Π³ΡΡΠ·ΠΈΡΡ ΡΠ°ΠΉΠ»", "Π‘ΡΡΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ"]
|
| 321 |
+
prev = st.session_state.db_mode
|
| 322 |
db_mode = st.radio(
|
| 323 |
+
"ΠΡΡΠΎΡΠ½ΠΈΠΊ Π΄Π°Π½Π½ΡΡ
", modes,
|
| 324 |
+
index=modes.index(prev) if prev in modes else None,
|
|
|
|
| 325 |
label_visibility="collapsed",
|
| 326 |
)
|
| 327 |
+
if db_mode != prev:
|
| 328 |
+
st.session_state.schema_tables = None
|
| 329 |
+
st.session_state.connection_string = ""
|
|
|
|
| 330 |
st.session_state.db_mode = db_mode
|
| 331 |
|
| 332 |
cs = ""
|
|
|
|
| 333 |
if db_mode == "ΠΠ΅ΠΌΠΎ-Π±Π°Π·Π°":
|
| 334 |
st.caption("ΠΡΡΡΠΎΠ΅Π½Π½Π°Ρ Π±Π°Π·Π°: ΠΈΠ½ΡΠ΅ΡΠ½Π΅Ρ-ΠΌΠ°Π³Π°Π·ΠΈΠ½ ΡΠ»Π΅ΠΊΡΡΠΎΠ½ΠΈΠΊΠΈ, 120 Π·Π°ΠΊΠ°Π·ΠΎΠ².")
|
| 335 |
+
cs = str(ROOT / "data" / "demo" / "sales.sqlite")
|
|
|
|
|
|
|
| 336 |
elif db_mode == "ΠΠ°Π³ΡΡΠ·ΠΈΡΡ ΡΠ°ΠΉΠ»":
|
| 337 |
uploaded = st.file_uploader(
|
| 338 |
+
"SQLite-ΡΠ°ΠΉΠ» Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
", type=["sqlite", "db"],
|
|
|
|
| 339 |
label_visibility="collapsed",
|
| 340 |
)
|
| 341 |
if uploaded:
|
|
|
|
| 345 |
cs = str(tmp_db)
|
| 346 |
else:
|
| 347 |
st.caption("ΠΠ΅ΡΠ΅ΡΠ°ΡΠΈΡΠ΅ .sqlite ΠΈΠ»ΠΈ .db ΡΠ°ΠΉΠ» ΡΡΠ΄Π°")
|
| 348 |
+
else:
|
|
|
|
| 349 |
cs = st.text_input(
|
| 350 |
"Π‘ΡΡΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ",
|
| 351 |
placeholder="postgresql://user:pass@host:5432/db",
|
| 352 |
+
value=st.session_state.connection_string,
|
| 353 |
label_visibility="collapsed",
|
| 354 |
)
|
| 355 |
+
st.caption("PostgreSQL Β· MySQL Β· SQLite (sqlite:///path)")
|
| 356 |
+
|
| 357 |
+
if cs and st.button("ΠΠΎΠ΄ΠΊΠ»ΡΡΠΈΡΡΡΡ", width='stretch', type="primary"):
|
| 358 |
+
with st.spinner("Π§ΡΠ΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌΡβ¦"):
|
| 359 |
+
tables, err = _api_get_schema(cs)
|
| 360 |
+
if err:
|
| 361 |
+
st.error(f"ΠΡΠΈΠ±ΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΡ: {err}")
|
| 362 |
+
st.session_state.schema_tables = None
|
| 363 |
+
else:
|
| 364 |
+
st.session_state.schema_tables = tables
|
| 365 |
+
st.session_state.connection_string = cs
|
| 366 |
+
st.session_state.schema_error = None
|
| 367 |
+
if not st.session_state.get("warmup_done", False):
|
| 368 |
+
with st.spinner("ΠΡΠΎΠ³ΡΠ΅Π² ΠΌΠΎΠ΄Π΅Π»ΠΈ (Π·Π°ΠΏΡΡΠΊΠ°Π΅ΡΡΡ ΠΎΠ΄ΠΈΠ½ ΡΠ°Π· Π·Π° ΡΠ΅ΡΡΠΈΡ)β¦"):
|
| 369 |
+
ok, _err = _api_warmup()
|
| 370 |
+
if ok:
|
| 371 |
+
st.session_state.warmup_done = True
|
| 372 |
+
# ΠΠ²ΡΠΎΠ·Π°Π³ΡΡΠ·ΠΊΠ° ΡΠ»ΠΎΠ²Π°ΡΡ Π΄Π»Ρ Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ
|
| 373 |
if "sales" in cs and st.session_state.vocabulary is None:
|
| 374 |
+
vp = ROOT / "configs" / "example_vocabulary.yaml"
|
| 375 |
+
if vp.exists():
|
| 376 |
+
try:
|
| 377 |
+
st.session_state.vocabulary = _load_vocab_from_yaml(
|
| 378 |
+
vp.read_text(encoding="utf-8")
|
| 379 |
+
)
|
| 380 |
+
except Exception:
|
| 381 |
+
pass
|
| 382 |
st.success(f"ΠΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΎ. Π’Π°Π±Π»ΠΈΡ: {len(tables)}")
|
|
|
|
|
|
|
| 383 |
|
| 384 |
+
if st.session_state.schema_tables is not None:
|
| 385 |
+
n = len(st.session_state.schema_tables)
|
| 386 |
st.markdown(
|
| 387 |
'<span class="status-ok">β
ΠΠ°Π·Π° Π΄Π°Π½Π½ΡΡ
ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½Π°</span>',
|
| 388 |
unsafe_allow_html=True,
|
| 389 |
)
|
| 390 |
+
with st.expander(f"Π’Π°Π±Π»ΠΈΡΡ ({n})"):
|
| 391 |
+
for t in st.session_state.schema_tables:
|
| 392 |
+
st.code(t.get("name", ""), language=None)
|
| 393 |
|
| 394 |
st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
|
| 395 |
|
| 396 |
# ββ ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ ββ
|
| 397 |
st.markdown('<p class="sb-label">ΠΠΈΠ·Π½Π΅Ρ-ΡΠ»ΠΎΠ²Π°ΡΡ</p>', unsafe_allow_html=True)
|
|
|
|
| 398 |
if st.session_state.vocabulary:
|
| 399 |
v = st.session_state.vocabulary
|
| 400 |
label = v.company if v.company else "ΠΠ°Π³ΡΡΠΆΠ΅Π½"
|
| 401 |
+
st.markdown(f'<span class="status-ok">β
{label}</span>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
| 402 |
if v.terms:
|
| 403 |
+
st.caption(f"Π’Π΅ΡΠΌΠΈΠ½ΠΎΠ²: {len(v.terms)}")
|
|
|
|
|
|
|
|
|
|
| 404 |
else:
|
| 405 |
+
st.caption("Π‘Π»ΠΎΠ²Π°ΡΡ Π½Π΅ ΠΏΡΠΈΠΌΠ΅Π½ΡΠ½")
|
| 406 |
+
if st.button("Π Π΅Π΄Π°ΠΊΡΠΈΡΠΎΠ²Π°ΡΡ ΡΠ»ΠΎΠ²Π°ΡΡ", width='stretch'):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
vocab_dialog()
|
| 408 |
|
| 409 |
|
|
|
|
|
|
|
| 410 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 411 |
+
# Π¨Π°ΠΏΠΊΠ°
|
| 412 |
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 413 |
st.markdown("""
|
| 414 |
<div class="app-header">
|
| 415 |
<p class="app-title">Ru2SQL β Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠ²Π½Π°Ρ ΠΌΠΎΠ΄Π΅Π»Ρ ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΡ Π·Π°ΠΏΡΠΎΡΠΎΠ²<br>
|
| 416 |
ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
Π½Π° ΡΡΡΡΠΊΠΎΠΌ ΡΠ·ΡΠΊΠ΅ Π² Π·Π°ΠΏΡΠΎΡΡ Π½Π° ΡΠ·ΡΠΊΠ΅ SQL</p>
|
| 417 |
<p class="app-subtitle">
|
| 418 |
+
Qwen2.5-Coder-3B-Instruct Β· QLoRA Π½Π° PAUQ
|
| 419 |
Β· SQLite / PostgreSQL / MySQL
|
| 420 |
</p>
|
| 421 |
</div>
|
|
|
|
| 423 |
|
| 424 |
tab_query, tab_schema, tab_history = st.tabs(["ΠΠ°ΠΏΡΠΎΡ", "Π‘Ρ
Π΅ΠΌΠ° Π±Π°Π·Ρ Π΄Π°Π½Π½ΡΡ
", "ΠΡΡΠΎΡΠΈΡ"])
|
| 425 |
|
| 426 |
+
|
| 427 |
+
# ββββββββββββ Tab: ΠΠ°ΠΏΡΠΎΡ ββββββββββββ
|
| 428 |
with tab_query:
|
| 429 |
+
api_ready = (
|
| 430 |
+
st.session_state.api_health is not None
|
| 431 |
+
and st.session_state.api_health.get("model_loaded", False)
|
| 432 |
+
)
|
| 433 |
+
db_ready = st.session_state.schema_tables is not None
|
| 434 |
+
ready = api_ready and db_ready
|
| 435 |
|
| 436 |
if not ready:
|
| 437 |
missing = []
|
| 438 |
+
if not api_ready:
|
| 439 |
+
missing.append("API/ΠΌΠΎΠ΄Π΅Π»Ρ Π½Π΅ Π³ΠΎΡΠΎΠ²Ρ")
|
| 440 |
+
if not db_ready:
|
| 441 |
missing.append("Π±Π°Π·Π° Π΄Π°Π½Π½ΡΡ
Π½Π΅ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½Π°")
|
| 442 |
st.warning("Π‘ΠΈΡΡΠ΅ΠΌΠ° Π½Π΅ Π³ΠΎΡΠΎΠ²Π°: " + ", ".join(missing) + ". ΠΡΠΏΠΎΠ»ΡΠ·ΡΠΉΡΠ΅ ΠΏΠ°Π½Π΅Π»Ρ ΡΠ»Π΅Π²Π°.")
|
| 443 |
|
|
|
|
| 448 |
disabled=not ready,
|
| 449 |
)
|
| 450 |
|
| 451 |
+
col_btn, _ = st.columns([1, 5])
|
| 452 |
with col_btn:
|
| 453 |
run_btn = st.button(
|
| 454 |
"ΠΡΠΏΠΎΠ»Π½ΠΈΡΡ",
|
| 455 |
type="primary",
|
| 456 |
disabled=not ready or not question.strip(),
|
| 457 |
+
width='stretch',
|
| 458 |
)
|
| 459 |
|
| 460 |
# ΠΡΠΈΠΌΠ΅ΡΡ Π΄Π»Ρ Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ
|
| 461 |
if (
|
| 462 |
+
st.session_state.connection_string
|
| 463 |
+
and "sales" in st.session_state.connection_string
|
| 464 |
):
|
| 465 |
st.markdown('<p class="examples-label">ΠΡΠΈΠΌΠ΅ΡΡ Π·Π°ΠΏΡΠΎΡΠΎΠ²</p>', unsafe_allow_html=True)
|
| 466 |
ex_cols = st.columns(3)
|
|
|
|
| 471 |
]
|
| 472 |
for i, ex in enumerate(examples):
|
| 473 |
with ex_cols[i]:
|
| 474 |
+
if st.button(ex, key=f"ex_{i}", width='stretch'):
|
| 475 |
question = ex
|
| 476 |
+
run_btn = True
|
| 477 |
|
| 478 |
if run_btn and question.strip():
|
| 479 |
+
cs = st.session_state.connection_string
|
| 480 |
+
vocab = st.session_state.vocabulary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
+
with st.spinner("ΠΠ°ΠΏΡΠΎΡ ΠΊ API. ΠΡΠΎ ΠΌΠΎΠΆΠ΅Ρ Π·Π°Π½ΡΡΡ Π½Π΅ΡΠΊΠΎΠ»ΡΠΊΠΎ ΠΌΠΈΠ½ΡΡ"):
|
| 483 |
+
try:
|
| 484 |
+
resp = _api_query(question, cs, vocab)
|
| 485 |
+
except Exception as e:
|
| 486 |
+
st.error(f"ΠΡΠΈΠ±ΠΊΠ°: {e}")
|
| 487 |
+
st.stop()
|
| 488 |
|
| 489 |
st.markdown("**Π‘Π³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΡΠΉ SQL**")
|
| 490 |
+
st.markdown(f'<div class="sql-box">{resp.get("sql", "")}</div>', unsafe_allow_html=True)
|
| 491 |
|
| 492 |
+
gen_time = resp.get("gen_time_seconds", 0.0)
|
| 493 |
+
execution = resp.get("execution")
|
| 494 |
+
err = resp.get("error")
|
|
|
|
| 495 |
|
| 496 |
c1, c2, c3 = st.columns(3)
|
| 497 |
c1.metric("ΠΡΠ΅ΠΌΡ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ", f"{gen_time:.1f} Ρ")
|
| 498 |
+
if execution:
|
| 499 |
+
c2.metric("Π‘ΡΡΠΎΠΊ ΠΏΠΎΠ»ΡΡΠ΅Π½ΠΎ", execution.get("row_count", 0))
|
| 500 |
+
c3.metric("Π‘ΡΠ°ΡΡΡ", "Π£ΡΠΏΠ΅ΡΠ½ΠΎ")
|
| 501 |
+
elif err:
|
| 502 |
+
c2.metric("Π‘ΡΡΠΎΠΊ ΠΏΠΎΠ»ΡΡΠ΅Π½ΠΎ", "β")
|
| 503 |
+
c3.metric("Π‘ΡΠ°ΡΡΡ", "ΠΡΠΈΠ±ΠΊΠ°")
|
| 504 |
+
|
| 505 |
+
if execution and execution.get("rows"):
|
| 506 |
+
import pandas as pd
|
| 507 |
+
st.markdown("**Π Π΅Π·ΡΠ»ΡΡΠ°Ρ**")
|
| 508 |
+
df = pd.DataFrame(execution["rows"], columns=execution["columns"])
|
| 509 |
+
st.dataframe(df, width='stretch')
|
| 510 |
+
elif execution and not execution.get("rows"):
|
| 511 |
+
st.info("ΠΠ°ΠΏΡΠΎΡ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ ΡΡΠΏΠ΅ΡΠ½ΠΎ. Π Π΅Π·ΡΠ»ΡΡΠ°Ρ ΠΏΡΡΡΠΎΠΉ.")
|
| 512 |
+
elif err:
|
| 513 |
+
st.error(f"ΠΡΠΈΠ±ΠΊΠ° Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ SQL: {err}")
|
| 514 |
|
| 515 |
st.session_state.history.append({
|
| 516 |
"question": question,
|
| 517 |
+
"sql": resp.get("sql", ""),
|
| 518 |
+
"success": bool(execution),
|
| 519 |
+
"rows": execution.get("row_count", 0) if execution else 0,
|
| 520 |
"time": gen_time,
|
| 521 |
})
|
| 522 |
|
| 523 |
+
|
| 524 |
+
# ββββββββββββ Tab: Π‘Ρ
Π΅ΠΌΠ° ΠΠ ββββββββββββ
|
| 525 |
with tab_schema:
|
| 526 |
+
if st.session_state.schema_tables is None:
|
| 527 |
st.info("ΠΠΎΠ΄ΠΊΠ»ΡΡΠΈΡΠ΅ΡΡ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½ΡΡ
ΡΠ΅ΡΠ΅Π· ΠΏΠ°Π½Π΅Π»Ρ ΡΠ»Π΅Π²Π°.")
|
| 528 |
else:
|
|
|
|
| 529 |
show_samples = st.toggle("ΠΠΎΠΊΠ°Π·ΡΠ²Π°ΡΡ ΠΏΡΠΈΠΌΠ΅ΡΡ Π΄Π°Π½Π½ΡΡ
", value=True)
|
| 530 |
+
for t in st.session_state.schema_tables:
|
| 531 |
+
with st.expander(f"{t['name']} β {len(t['columns'])} ΠΊΠΎΠ»ΠΎΠ½ΠΎΠΊ"):
|
| 532 |
+
st.code(t.get("ddl", ""), language="sql")
|
| 533 |
+
if show_samples and t.get("sample_rows"):
|
|
|
|
| 534 |
import pandas as pd
|
| 535 |
+
cols = [c["name"] for c in t["columns"]]
|
| 536 |
st.caption("ΠΡΠΈΠΌΠ΅ΡΡ Π΄Π°Π½Π½ΡΡ
:")
|
| 537 |
st.dataframe(
|
| 538 |
+
pd.DataFrame(t["sample_rows"], columns=cols),
|
| 539 |
+
width='stretch',
|
| 540 |
)
|
| 541 |
|
| 542 |
+
|
| 543 |
+
# ββββββββββββ Tab: ΠΡΡΠΎΡΠΈΡ ββββββββββββ
|
| 544 |
with tab_history:
|
| 545 |
history = st.session_state.history
|
| 546 |
if not history:
|
|
|
|
| 550 |
with col_h:
|
| 551 |
st.markdown(f"**ΠΠ°ΠΏΡΠΎΡΠΎΠ² Π² ΡΠ΅ΡΡΠΈΠΈ: {len(history)}**")
|
| 552 |
with col_clr:
|
| 553 |
+
if st.button("ΠΡΠΈΡΡΠΈΡΡ", width='stretch'):
|
| 554 |
st.session_state.history = []
|
| 555 |
st.rerun()
|
| 556 |
|
tests/test_db.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π’Π΅ΡΡΡ Π½Π° DbConnector ΠΈ SqlExecutor.
|
| 2 |
+
|
| 3 |
+
ΠΠΎΠΊΡΡΠ²Π°ΡΡ ΡΡΠ΅Π½ΠΈΠ΅ ΡΡ
Π΅ΠΌΡ SQLite-Π±Π°Π·, Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ DDL ΠΈ ΠΏΡΠΎΠ²Π΅ΡΠΊΡ ΡΠΎΠ³ΠΎ, ΡΡΠΎ
|
| 4 |
+
SQLite-ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ Π΄Π΅ΠΉΡΡΠ²ΠΈΡΠ΅Π»ΡΠ½ΠΎ ΠΎΡΠΊΡΡΠ²Π°Π΅ΡΡΡ Π² ΡΠ΅ΠΆΠΈΠΌΠ΅ read-only β
|
| 5 |
+
ΠΌΠΎΠ΄ΠΈΡΠΈΡΠΈΡΡΡΡΠΈΠ΅ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ ΠΏΠ°Π΄Π°ΡΡ Ρ sqlite3.OperationalError.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sqlite3
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import pytest
|
| 12 |
+
|
| 13 |
+
from src.db.connector import DbConnector, TableInfo
|
| 14 |
+
from src.db.executor import QueryResult, SqlExecutor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
|
| 18 |
+
def tiny_sqlite(tmp_path: Path) -> Path:
|
| 19 |
+
db = tmp_path / "tiny.sqlite"
|
| 20 |
+
conn = sqlite3.connect(db)
|
| 21 |
+
conn.execute(
|
| 22 |
+
"CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, city TEXT)"
|
| 23 |
+
)
|
| 24 |
+
conn.executemany(
|
| 25 |
+
"INSERT INTO users (id, name, city) VALUES (?, ?, ?)",
|
| 26 |
+
[(1, "ΠΠ²Π°Π½", "ΠΠ°Π·Π°Π½Ρ"), (2, "ΠΠ½Π½Π°", "ΠΠΎΡΠΊΠ²Π°"), (3, "ΠΠ»Π΅Π³", "ΠΠ°Π·Π°Π½Ρ")],
|
| 27 |
+
)
|
| 28 |
+
conn.commit()
|
| 29 |
+
conn.close()
|
| 30 |
+
return db
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
# DbConnector
|
| 35 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
def test_connector_lists_tables(tiny_sqlite: Path):
|
| 38 |
+
c = DbConnector(str(tiny_sqlite))
|
| 39 |
+
assert c.list_tables() == ["users"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_connector_reads_columns(tiny_sqlite: Path):
|
| 43 |
+
c = DbConnector(str(tiny_sqlite))
|
| 44 |
+
tables = c.get_schema(include_samples=False)
|
| 45 |
+
assert len(tables) == 1
|
| 46 |
+
table = tables[0]
|
| 47 |
+
assert isinstance(table, TableInfo)
|
| 48 |
+
names = [col.name for col in table.columns]
|
| 49 |
+
assert names == ["id", "name", "city"]
|
| 50 |
+
# id β primary key, name β NOT NULL
|
| 51 |
+
pk = next(col for col in table.columns if col.name == "id")
|
| 52 |
+
assert pk.primary_key is True
|
| 53 |
+
nn = next(col for col in table.columns if col.name == "name")
|
| 54 |
+
assert nn.nullable is False
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_connector_renders_ddl(tiny_sqlite: Path):
|
| 58 |
+
c = DbConnector(str(tiny_sqlite))
|
| 59 |
+
schema_text = c.render_schema(include_samples=True)
|
| 60 |
+
assert "CREATE TABLE users" in schema_text
|
| 61 |
+
assert "PRIMARY KEY" in schema_text
|
| 62 |
+
# sample-ΡΡΡΠΎΠΊΠΈ ΠΏΡΠΎΠΊΠΈΠ½ΡΡΡ ΠΊΠΎΠΌΠΌΠ΅Π½ΡΠ°ΡΠΈΡΠΌΠΈ
|
| 63 |
+
assert "ΠΠ²Π°Π½" in schema_text or "ΠΠ»Π΅Π³" in schema_text
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_connector_accepts_sqlite_uri(tiny_sqlite: Path):
|
| 67 |
+
c = DbConnector(f"sqlite:///{tiny_sqlite}")
|
| 68 |
+
assert c.list_tables() == ["users"]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
# SqlExecutor
|
| 73 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
def test_executor_runs_select(tiny_sqlite: Path):
|
| 76 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 77 |
+
res = ex.run("SELECT id, name FROM users ORDER BY id")
|
| 78 |
+
assert isinstance(res, QueryResult)
|
| 79 |
+
assert res.success
|
| 80 |
+
assert res.columns == ["id", "name"]
|
| 81 |
+
assert res.row_count == 3
|
| 82 |
+
assert res.rows[0] == [1, "ΠΠ²Π°Π½"]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_executor_aggregation(tiny_sqlite: Path):
|
| 86 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 87 |
+
res = ex.run("SELECT city, COUNT(*) AS cnt FROM users GROUP BY city ORDER BY cnt DESC")
|
| 88 |
+
assert res.success
|
| 89 |
+
assert res.rows[0] == ["ΠΠ°Π·Π°Π½Ρ", 2]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def test_executor_returns_error_on_bad_sql(tiny_sqlite: Path):
|
| 93 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 94 |
+
res = ex.run("SELEC nonsense FROM users")
|
| 95 |
+
assert not res.success
|
| 96 |
+
assert res.error is not None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def test_executor_blocks_modifications(tiny_sqlite: Path):
|
| 100 |
+
"""ΠΠ»ΡΡΠ΅Π²Π°Ρ ΠΏΡΠΎΠ²Π΅ΡΠΊΠ°: SQLite-ΡΠΎΠ΅Π΄ΠΈΠ½Π΅Π½ΠΈΠ΅ ΠΎΡΠΊΡΡΠ²Π°Π΅ΡΡΡ Π² read-only
|
| 101 |
+
ΡΠ΅ΠΆΠΈΠΌΠ΅ (URI mode=ro&immutable=1), ΠΌΠΎΠ΄ΠΈΡΠΈΡΠΈΡΡΡΡΠΈΠ΅ ΠΎΠΏΠ΅ΡΠ°ΡΠΈΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ
|
| 102 |
+
ΠΏΠ°Π΄Π°ΡΡ ΠΎΡΠΈΠ±ΠΊΠΎΠΉ, Π° Π½Π΅ Π²ΡΠΏΠΎΠ»Π½ΡΡΡΡΡ Π²ΡΠΈΡ
ΡΡ."""
|
| 103 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 104 |
+
|
| 105 |
+
res = ex.run("DELETE FROM users WHERE id = 1")
|
| 106 |
+
assert not res.success
|
| 107 |
+
assert res.error is not None
|
| 108 |
+
assert "read" in res.error.lower() or "readonly" in res.error.lower() \
|
| 109 |
+
or "ΡΠΎΠ»ΡΠΊΠΎ Π΄Π»Ρ ΡΡΠ΅Π½ΠΈΡ" in res.error.lower()
|
| 110 |
+
|
| 111 |
+
# ΠΠΎΠ΄ΡΠ²Π΅ΡΠΆΠ΄Π΅Π½ΠΈΠ΅, ΡΡΠΎ Π΄Π°Π½Π½ΡΠ΅ Π½Π΅ ΠΏΠΎΡΡΡΠ°Π΄Π°Π»ΠΈ
|
| 112 |
+
check = ex.run("SELECT COUNT(*) FROM users")
|
| 113 |
+
assert check.success
|
| 114 |
+
assert check.rows == [[3]]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def test_executor_blocks_drop_table(tiny_sqlite: Path):
|
| 118 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 119 |
+
res = ex.run("DROP TABLE users")
|
| 120 |
+
assert not res.success
|
| 121 |
+
|
| 122 |
+
# ΠΠΎΠ΄ΡΠ²Π΅ΡΠΆΠ΄Π΅Π½ΠΈΠ΅, ΡΡΠΎ ΡΠ°Π±Π»ΠΈΡΠ° Π½Π° ΠΌΠ΅ΡΡΠ΅
|
| 123 |
+
check = ex.run("SELECT COUNT(*) FROM users")
|
| 124 |
+
assert check.success
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_queryresult_to_markdown(tiny_sqlite: Path):
|
| 128 |
+
ex = SqlExecutor(str(tiny_sqlite))
|
| 129 |
+
res = ex.run("SELECT id, name FROM users WHERE id = 1")
|
| 130 |
+
md = res.to_markdown_table()
|
| 131 |
+
assert "id" in md and "name" in md
|
| 132 |
+
assert "ΠΠ²Π°Π½" in md
|
tests/test_postprocess.py
CHANGED
|
@@ -1,6 +1,12 @@
|
|
| 1 |
-
"""Π’Π΅ΡΡΡ Π½Π° ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΡ SQL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from src.models.postprocess import (
|
|
|
|
| 4 |
is_valid_sql,
|
| 5 |
normalize_sql,
|
| 6 |
postprocess,
|
|
@@ -8,39 +14,165 @@ from src.models.postprocess import (
|
|
| 8 |
)
|
| 9 |
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
raw = "```sql\nSELECT * FROM users;\n```"
|
| 13 |
-
assert strip_model_artifacts(raw).startswith("SELECT")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def test_strip_sql_prefix():
|
| 17 |
raw = "SQL: SELECT 1;"
|
| 18 |
-
assert strip_model_artifacts(raw).startswith("SELECT")
|
| 19 |
|
| 20 |
|
| 21 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
raw = "SELECT 1; SELECT 2;"
|
| 23 |
out = strip_model_artifacts(raw)
|
| 24 |
assert "SELECT 1" in out
|
| 25 |
assert "SELECT 2" not in out
|
| 26 |
|
| 27 |
|
| 28 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
assert is_valid_sql("SELECT * FROM students WHERE id = 1")
|
| 30 |
|
| 31 |
|
| 32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
assert not is_valid_sql("SELEC * FRM where")
|
| 34 |
|
| 35 |
|
| 36 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
a = "SELECT * FROM Users"
|
| 38 |
b = "select * from users"
|
| 39 |
assert normalize_sql(a) == normalize_sql(b)
|
| 40 |
|
| 41 |
|
| 42 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
|
| 44 |
out = postprocess(raw)
|
| 45 |
-
assert out.startswith("SELECT name")
|
| 46 |
assert "SELECT 2" not in out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π’Π΅ΡΡΡ Π½Π° ΠΏΠΎΡΡΠΎΠ±ΡΠ°Π±ΠΎΡΠΊΡ SQL ΠΈ ΡΠ²ΡΠ·Π°Π½Π½ΡΠ΅ ΡΡΠ½ΠΊΡΠΈΠΈ.
|
| 2 |
+
|
| 3 |
+
ΠΠΎΠΊΡΡΠ²Π°Π΅Ρ ΡΠ°Π·Π΄Π΅Π» 2.5 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ: ΡΠΈΡΡΠΊΡ Π°ΡΡΠ΅ΡΠ°ΠΊΡΠΎΠ²,
|
| 4 |
+
Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΡ ΡΠ΅ΡΠ΅Π· sqlglot, Π½ΠΎΡΠΌΠ°Π»ΠΈΠ·Π°ΡΠΈΡ Π΄Π»Ρ Exact Match ΠΈ AST-ΡΡΠΎΠ²Π½Π΅Π²ΡΠΉ
|
| 5 |
+
Π³Π²Π°ΡΠ΄Π΅ΠΉΠ» is_select_only.
|
| 6 |
+
"""
|
| 7 |
|
| 8 |
from src.models.postprocess import (
|
| 9 |
+
is_select_only,
|
| 10 |
is_valid_sql,
|
| 11 |
normalize_sql,
|
| 12 |
postprocess,
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
|
| 17 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
# strip_model_artifacts
|
| 19 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
|
| 21 |
+
def test_strip_markdown_block_with_lang():
|
| 22 |
raw = "```sql\nSELECT * FROM users;\n```"
|
| 23 |
+
assert strip_model_artifacts(raw).upper().startswith("SELECT")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def test_strip_markdown_block_without_lang():
|
| 27 |
+
raw = "```\nSELECT id FROM t;\n```"
|
| 28 |
+
assert strip_model_artifacts(raw).upper().startswith("SELECT")
|
| 29 |
|
| 30 |
|
| 31 |
def test_strip_sql_prefix():
|
| 32 |
raw = "SQL: SELECT 1;"
|
| 33 |
+
assert strip_model_artifacts(raw).upper().startswith("SELECT")
|
| 34 |
|
| 35 |
|
| 36 |
+
def test_strip_russian_prefix():
|
| 37 |
+
raw = "ΠΡΠ²Π΅Ρ: SELECT name FROM students;"
|
| 38 |
+
assert strip_model_artifacts(raw).upper().startswith("SELECT")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_strip_natural_language_before_select():
|
| 42 |
+
raw = "ΠΠΎΡ SQL, ΠΊΠΎΡΠΎΡΡΠΉ ΠΎΡΠ²Π΅ΡΠ°Π΅Ρ Π½Π° Π²ΠΎΠΏΡΠΎΡ: SELECT * FROM t WHERE id = 1;"
|
| 43 |
+
out = strip_model_artifacts(raw)
|
| 44 |
+
assert out.upper().startswith("SELECT")
|
| 45 |
+
assert "ΠΠΎΡ" not in out
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_keeps_first_statement_of_two():
|
| 49 |
raw = "SELECT 1; SELECT 2;"
|
| 50 |
out = strip_model_artifacts(raw)
|
| 51 |
assert "SELECT 1" in out
|
| 52 |
assert "SELECT 2" not in out
|
| 53 |
|
| 54 |
|
| 55 |
+
def test_with_cte_is_preserved():
|
| 56 |
+
raw = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
|
| 57 |
+
out = strip_model_artifacts(raw)
|
| 58 |
+
assert out.upper().startswith("WITH")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_strip_returns_empty_on_garbage():
|
| 62 |
+
# ΠΠ΅Ρ Π½ΠΈ ΠΎΠ΄Π½ΠΎΠ³ΠΎ SQL-ΠΊΠ»ΡΡΠ΅Π²ΠΎΠ³ΠΎ ΡΠ»ΠΎΠ²Π° β ΠΎΠ±ΡΠ΅Π·Π°ΡΡ Π½Π΅ΡΠ΅Π³ΠΎ, Π½ΠΎ ΠΈ ΠΏΡΡΡΠΎΠ³ΠΎ
|
| 63 |
+
# ΠΎΡΠ²Π΅ΡΠ° ΠΌΠΎΠ΄Π΅Π»Ρ Π΅ΡΡ Π½Π΅ Π½Π°Π³Π΅Π½Π΅ΡΠΈΠ»Π°: Π²ΠΎΠ·Π²ΡΠ°ΡΠ°Π΅ΠΌ ΠΊΠ°ΠΊ Π΅ΡΡΡ, Π²Π°Π»ΠΈΠ΄Π°ΡΠΈΡ
|
| 64 |
+
# ΠΎΡΡΠ΅Π΅Ρ Π΄Π°Π»ΡΡΠ΅ ΠΏΠΎ ΠΏΠ°ΠΉΠΏΠ»Π°ΠΉΠ½Ρ.
|
| 65 |
+
raw = "ΠΏΡΠΎΡΡΠΎ ΡΠ΅ΠΊΡΡ Π±Π΅Π· Π·Π°ΠΏΡΠΎΡΠ°"
|
| 66 |
+
assert strip_model_artifacts(raw) == "ΠΏΡΠΎΡΡΠΎ ΡΠ΅ΠΊΡΡ Π±Π΅Π· Π·Π°ΠΏΡΠΎΡΠ°"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
# is_valid_sql
|
| 71 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
|
| 73 |
+
def test_valid_select():
|
| 74 |
assert is_valid_sql("SELECT * FROM students WHERE id = 1")
|
| 75 |
|
| 76 |
|
| 77 |
+
def test_valid_with_cte():
|
| 78 |
+
assert is_valid_sql("WITH x AS (SELECT id FROM t) SELECT * FROM x")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def test_invalid_garbage():
|
| 82 |
assert not is_valid_sql("SELEC * FRM where")
|
| 83 |
|
| 84 |
|
| 85 |
+
def test_invalid_empty():
|
| 86 |
+
assert not is_valid_sql("")
|
| 87 |
+
assert not is_valid_sql(" ")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 91 |
+
# is_select_only β guardrail
|
| 92 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 93 |
+
|
| 94 |
+
def test_select_passes_guardrail():
|
| 95 |
+
assert is_select_only("SELECT id FROM t")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_with_cte_passes_guardrail():
|
| 99 |
+
assert is_select_only("WITH x AS (SELECT id FROM t) SELECT * FROM x")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def test_drop_table_blocked():
|
| 103 |
+
assert not is_select_only("DROP TABLE users")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_delete_blocked():
|
| 107 |
+
assert not is_select_only("DELETE FROM users WHERE id = 1")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_update_blocked():
|
| 111 |
+
assert not is_select_only("UPDATE users SET name = 'a' WHERE id = 1")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_insert_blocked():
|
| 115 |
+
assert not is_select_only("INSERT INTO users (id, name) VALUES (1, 'a')")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_empty_blocked():
|
| 119 |
+
assert not is_select_only("")
|
| 120 |
+
assert not is_select_only(" ")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def test_invalid_sql_blocked_by_guardrail():
|
| 124 |
+
# ΠΠ° Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΠΎΠΉ ΡΡΡΠΎΠΊΠ΅ is_select_only Π΄ΠΎΠ»ΠΆΠ΅Π½ ΡΠ΅ΡΡΠ½ΠΎ Π²ΠΎΠ·Π²ΡΠ°ΡΠ°ΡΡ False,
|
| 125 |
+
# Π° Π½Π΅ ΠΏΠ°Π΄Π°ΡΡ Ρ ΠΈΡΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ΠΌ.
|
| 126 |
+
assert not is_select_only("not a sql at all")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
# normalize_sql
|
| 131 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 132 |
+
|
| 133 |
+
def test_normalize_collapses_whitespace():
|
| 134 |
a = "SELECT * FROM Users"
|
| 135 |
b = "select * from users"
|
| 136 |
assert normalize_sql(a) == normalize_sql(b)
|
| 137 |
|
| 138 |
|
| 139 |
+
def test_normalize_idempotent():
|
| 140 |
+
sql = "SELECT id FROM t WHERE x = 1"
|
| 141 |
+
assert normalize_sql(normalize_sql(sql)) == normalize_sql(sql)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_normalize_fallback_on_invalid():
|
| 145 |
+
# ΠΠ° Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΠΎΠΌ SQL ΡΡΠ½ΠΊΡΠΈΡ Π½Π΅ Π΄ΠΎΠ»ΠΆΠ½Π° ΠΏΠ°Π΄Π°ΡΡ β Π΄ΠΎΠ»ΠΆΠ΅Π½ ΡΡΠ°Π±ΠΎΡΠ°ΡΡ fallback.
|
| 146 |
+
out = normalize_sql("not really sql")
|
| 147 |
+
assert isinstance(out, str)
|
| 148 |
+
assert out.upper() == out # Π²Π΅ΡΡ
Π½ΠΈΠΉ ΡΠ΅Π³ΠΈΡΡΡ ΡΠΎΡ
ΡΠ°Π½ΡΠ½
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 152 |
+
# postprocess β ΠΏΠΎΠ»Π½ΡΠΉ pipeline
|
| 153 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 154 |
+
|
| 155 |
+
def test_postprocess_extracts_from_markdown():
|
| 156 |
raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
|
| 157 |
out = postprocess(raw)
|
| 158 |
+
assert out.upper().startswith("SELECT NAME") or out.startswith("SELECT name")
|
| 159 |
assert "SELECT 2" not in out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_postprocess_returns_empty_on_invalid():
|
| 163 |
+
# Π’Π΅ΠΊΡΡ Π½Π΅ ΡΠΎΠ΄Π΅ΡΠΆΠΈΡ Π²Π°Π»ΠΈΠ΄Π½ΠΎΠ³ΠΎ SQL β pipeline Π΄ΠΎΠ»ΠΆΠ΅Π½ Π²Π΅ΡΠ½ΡΡΡ ΠΏΡΡΡΡΡ ΡΡΡΠΎΠΊΡ,
|
| 164 |
+
# ΠΊΠ°ΠΊ ΠΎΠΏΠΈΡΠ°Π½ΠΎ Π² ΡΠ°Π·Π΄Π΅Π»Π΅ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ.
|
| 165 |
+
raw = "Π― Π½Π΅ ΠΌΠΎΠ³Ρ ΡΠ³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°ΡΡ SQL Π΄Π»Ρ ΡΡΠΎΠ³ΠΎ Π²ΠΎΠΏΡΠΎΡΠ°."
|
| 166 |
+
assert postprocess(raw) == ""
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_postprocess_returns_empty_on_truncated():
|
| 170 |
+
# ΠΠΎΠ΄Π΅Π»Ρ ΠΎΠ±ΠΎΡΠ²Π°Π»Π° Π³Π΅Π½Π΅ΡΠ°ΡΠΈΡ Π½Π° ΡΠ΅ΡΠ΅Π΄ΠΈΠ½Π΅ Π·Π°ΠΏΡΠΎΡΠ° β Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΡΠΉ ΡΠΈΠ½ΡΠ°ΠΊΡΠΈΡ.
|
| 171 |
+
raw = "SELECT * FROM users WHERE"
|
| 172 |
+
assert postprocess(raw) == ""
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def test_postprocess_keeps_valid_with_cte():
|
| 176 |
+
raw = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
|
| 177 |
+
out = postprocess(raw)
|
| 178 |
+
assert out.upper().startswith("WITH")
|
tests/test_prompt.py
CHANGED
|
@@ -1,8 +1,15 @@
|
|
| 1 |
-
"""Π’Π΅ΡΡΡ Π½Π° PromptBuilder.
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from src.data.prompt import (
|
|
|
|
| 4 |
SYSTEM_PROMPT,
|
| 5 |
build_chat_messages,
|
|
|
|
| 6 |
build_training_example,
|
| 7 |
build_user_message,
|
| 8 |
)
|
|
@@ -21,7 +28,7 @@ def test_chat_messages_have_system_and_user():
|
|
| 21 |
msgs = build_chat_messages("schema", "question")
|
| 22 |
assert len(msgs) == 2
|
| 23 |
assert msgs[0]["role"] == "system"
|
| 24 |
-
assert msgs[0]["content"] ==
|
| 25 |
assert msgs[1]["role"] == "user"
|
| 26 |
|
| 27 |
|
|
@@ -30,3 +37,51 @@ def test_training_example_has_assistant():
|
|
| 30 |
assert len(msgs) == 3
|
| 31 |
assert msgs[2]["role"] == "assistant"
|
| 32 |
assert msgs[2]["content"] == "SELECT 1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π’Π΅ΡΡΡ Π½Π° PromptBuilder.
|
| 2 |
|
| 3 |
+
ΠΠΎΠΊΡΡΠ²Π°ΡΡ ΠΊΠ°ΠΊ Π±Π°Π·ΠΎΠ²ΠΎΠ΅ ΡΠΎΡΠΌΠΈΡΠΎΠ²Π°Π½ΠΈΠ΅ chat-template, ΡΠ°ΠΊ ΠΈ ΠΎΠΏΡΠΈΠΎΠ½Π°Π»ΡΠ½ΡΡ
|
| 4 |
+
ΠΈΠ½ΡΠ΅Π³ΡΠ°ΡΠΈΡ BusinessVocabulary Π² ΡΠΈΡΡΠ΅ΠΌΠ½ΠΎΠ΅ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅ (ΡΠ°Π·Π΄Π΅Π» 3.6 ΠΠΠ ).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 8 |
from src.data.prompt import (
|
| 9 |
+
BASE_SYSTEM_PROMPT,
|
| 10 |
SYSTEM_PROMPT,
|
| 11 |
build_chat_messages,
|
| 12 |
+
build_system_message,
|
| 13 |
build_training_example,
|
| 14 |
build_user_message,
|
| 15 |
)
|
|
|
|
| 28 |
msgs = build_chat_messages("schema", "question")
|
| 29 |
assert len(msgs) == 2
|
| 30 |
assert msgs[0]["role"] == "system"
|
| 31 |
+
assert msgs[0]["content"] == BASE_SYSTEM_PROMPT
|
| 32 |
assert msgs[1]["role"] == "user"
|
| 33 |
|
| 34 |
|
|
|
|
| 37 |
assert len(msgs) == 3
|
| 38 |
assert msgs[2]["role"] == "assistant"
|
| 39 |
assert msgs[2]["content"] == "SELECT 1"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_legacy_system_prompt_alias():
|
| 43 |
+
assert SYSTEM_PROMPT == BASE_SYSTEM_PROMPT
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_system_message_without_vocabulary():
|
| 47 |
+
assert build_system_message(None) == BASE_SYSTEM_PROMPT
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_system_message_with_empty_vocabulary():
|
| 51 |
+
vocab = BusinessVocabulary.empty()
|
| 52 |
+
assert build_system_message(vocab) == BASE_SYSTEM_PROMPT
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_system_message_with_terms():
|
| 56 |
+
vocab = BusinessVocabulary(
|
| 57 |
+
company="ΠΠΠ Π ΠΎΠΌΠ°ΡΠΊΠ°",
|
| 58 |
+
terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(orders.amount) WHERE orders.status = 'paid'"},
|
| 59 |
+
)
|
| 60 |
+
msg = build_system_message(vocab)
|
| 61 |
+
assert msg.startswith(BASE_SYSTEM_PROMPT)
|
| 62 |
+
assert "ΠΠΠ Π ΠΎΠΌΠ°ΡΠΊΠ°" in msg
|
| 63 |
+
assert "Π²ΡΡΡΡΠΊΠ°" in msg
|
| 64 |
+
assert "SUM(orders.amount)" in msg
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_chat_messages_with_vocabulary_keeps_user_clean():
|
| 68 |
+
vocab = BusinessVocabulary(
|
| 69 |
+
terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(amount) WHERE status='paid'"},
|
| 70 |
+
)
|
| 71 |
+
msgs = build_chat_messages("schema", "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ°?", vocabulary=vocab)
|
| 72 |
+
assert msgs[0]["role"] == "system"
|
| 73 |
+
assert "SUM(amount)" in msgs[0]["content"]
|
| 74 |
+
assert msgs[1]["role"] == "user"
|
| 75 |
+
assert "SUM(amount)" not in msgs[1]["content"]
|
| 76 |
+
assert "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ°?" in msgs[1]["content"]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_training_example_with_vocabulary():
|
| 80 |
+
vocab = BusinessVocabulary(terms={"ΡΠΎΠΏ": "ORDER BY x DESC LIMIT 10"})
|
| 81 |
+
msgs = build_training_example(
|
| 82 |
+
"schema", "Π’ΠΎΠΏ ΠΊΠ»ΠΈΠ΅Π½ΡΠΎΠ²", "SELECT 1", vocabulary=vocab
|
| 83 |
+
)
|
| 84 |
+
assert len(msgs) == 3
|
| 85 |
+
assert msgs[0]["role"] == "system"
|
| 86 |
+
assert "ORDER BY x DESC" in msgs[0]["content"]
|
| 87 |
+
assert msgs[2]["role"] == "assistant"
|
tests/test_schema_provider.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π’Π΅ΡΡΡ Π½Π° Π΅Π΄ΠΈΠ½ΡΠΉ SchemaProvider (ΡΠ°Π·Π΄Π΅Π» 4.2 Π°ΡΠ΄ΠΈΡΠ°).
|
| 2 |
+
|
| 3 |
+
ΠΠΎΠΊΡΡΠ²Π°ΡΡ ΠΎΠ±Π΅ ΡΠ΅Π°Π»ΠΈΠ·Π°ΡΠΈΠΈ: SpiderSchemaProvider (ΡΡΡΡΠΊΡΡΡΠ° PAUQ/Spider)
|
| 4 |
+
ΠΈ ConnectionSchemaProvider (ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ»ΡΠ½ΠΎΠ΅ ΠΏΠΎΠ΄ΠΊΠ»ΡΡΠ΅Π½ΠΈΠ΅ ΠΊ SQLite-ΡΠ°ΠΉΠ»Ρ).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sqlite3
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
|
| 12 |
+
from src.data.schema_provider import (
|
| 13 |
+
ColumnSchema,
|
| 14 |
+
ConnectionSchemaProvider,
|
| 15 |
+
SpiderSchemaProvider,
|
| 16 |
+
TableSchema,
|
| 17 |
+
render_tables,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
# Π€ΠΈΠΊΡΡΡΡΡ
|
| 23 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
|
| 25 |
+
@pytest.fixture
|
| 26 |
+
def spider_dir(tmp_path: Path) -> Path:
|
| 27 |
+
"""data/databases/uni/uni.sqlite + data/databases/sales/sales.sqlite."""
|
| 28 |
+
for db_id in ("uni", "sales"):
|
| 29 |
+
(tmp_path / db_id).mkdir()
|
| 30 |
+
db = tmp_path / db_id / f"{db_id}.sqlite"
|
| 31 |
+
conn = sqlite3.connect(db)
|
| 32 |
+
conn.execute(f"CREATE TABLE {db_id}_t (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
|
| 33 |
+
conn.execute(f"INSERT INTO {db_id}_t VALUES (1, '{db_id}-row')")
|
| 34 |
+
conn.commit()
|
| 35 |
+
conn.close()
|
| 36 |
+
return tmp_path
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@pytest.fixture
|
| 40 |
+
def tiny_db(tmp_path: Path) -> Path:
|
| 41 |
+
db = tmp_path / "tiny.sqlite"
|
| 42 |
+
conn = sqlite3.connect(db)
|
| 43 |
+
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
| 44 |
+
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "ΠΠ²Π°Π½"), (2, "ΠΠ½Π½Π°")])
|
| 45 |
+
conn.commit()
|
| 46 |
+
conn.close()
|
| 47 |
+
return db
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
# TableSchema.to_ddl
|
| 52 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 53 |
+
|
| 54 |
+
def test_table_schema_to_ddl_from_create_sql():
|
| 55 |
+
t = TableSchema(
|
| 56 |
+
name="t",
|
| 57 |
+
create_sql="CREATE TABLE t (id INT PRIMARY KEY, name TEXT)",
|
| 58 |
+
)
|
| 59 |
+
assert t.to_ddl() == "CREATE TABLE t (id INT PRIMARY KEY, name TEXT);"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test_table_schema_to_ddl_from_columns():
|
| 63 |
+
t = TableSchema(
|
| 64 |
+
name="users",
|
| 65 |
+
columns=[
|
| 66 |
+
ColumnSchema(name="id", type="INTEGER", primary_key=True, nullable=False),
|
| 67 |
+
ColumnSchema(name="email", type="TEXT", nullable=False),
|
| 68 |
+
],
|
| 69 |
+
)
|
| 70 |
+
ddl = t.to_ddl()
|
| 71 |
+
assert ddl.startswith("CREATE TABLE users")
|
| 72 |
+
assert "id INTEGER PRIMARY KEY NOT NULL" in ddl
|
| 73 |
+
assert "email TEXT NOT NULL" in ddl
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
# SpiderSchemaProvider
|
| 78 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 79 |
+
|
| 80 |
+
def test_spider_lists_databases(spider_dir: Path):
|
| 81 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 82 |
+
assert p.list_databases() == ["sales", "uni"]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_spider_db_path_resolves(spider_dir: Path):
|
| 86 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 87 |
+
path = p.db_path("uni")
|
| 88 |
+
assert path.exists()
|
| 89 |
+
assert path.name == "uni.sqlite"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def test_spider_db_path_raises_on_missing(spider_dir: Path):
|
| 93 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 94 |
+
with pytest.raises(FileNotFoundError):
|
| 95 |
+
p.db_path("nonexistent")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def test_spider_get_tables_returns_tableschema(spider_dir: Path):
|
| 99 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 100 |
+
tables = p.get_tables("uni")
|
| 101 |
+
assert len(tables) == 1
|
| 102 |
+
assert isinstance(tables[0], TableSchema)
|
| 103 |
+
assert tables[0].name == "uni_t"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_spider_render_schema_has_create(spider_dir: Path):
|
| 107 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 108 |
+
text = p.render_schema("uni")
|
| 109 |
+
assert "CREATE TABLE" in text
|
| 110 |
+
assert "uni_t" in text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 114 |
+
# ConnectionSchemaProvider
|
| 115 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
+
|
| 117 |
+
def test_connection_lists_tables(tiny_db: Path):
|
| 118 |
+
p = ConnectionSchemaProvider(str(tiny_db))
|
| 119 |
+
assert p.list_tables() == ["users"]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_connection_get_tables_columns(tiny_db: Path):
|
| 123 |
+
p = ConnectionSchemaProvider(str(tiny_db))
|
| 124 |
+
tables = p.get_tables()
|
| 125 |
+
assert len(tables) == 1
|
| 126 |
+
cols = {c.name for c in tables[0].columns}
|
| 127 |
+
assert cols == {"id", "name"}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_connection_render_schema_with_samples(tiny_db: Path):
|
| 131 |
+
p = ConnectionSchemaProvider(str(tiny_db))
|
| 132 |
+
text = p.render_schema(include_samples=True)
|
| 133 |
+
assert "CREATE TABLE users" in text
|
| 134 |
+
assert "ΠΠ²Π°Π½" in text or "ΠΠ½Π½Π°" in text
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_connection_test_connection(tiny_db: Path):
|
| 138 |
+
p = ConnectionSchemaProvider(str(tiny_db))
|
| 139 |
+
assert p.test_connection() is True
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 143 |
+
# Π¦Π΅ΠΏΠΎΡΠΊΠ° SpiderSchemaProvider.for_database β ConnectionSchemaProvider
|
| 144 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 145 |
+
|
| 146 |
+
def test_spider_for_database_returns_connection_provider(spider_dir: Path):
|
| 147 |
+
p = SpiderSchemaProvider(spider_dir)
|
| 148 |
+
sub = p.for_database("sales")
|
| 149 |
+
assert isinstance(sub, ConnectionSchemaProvider)
|
| 150 |
+
text = sub.render_schema()
|
| 151 |
+
assert "sales_t" in text
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
# render_tables β ΠΎΠ±ΡΠ°Ρ ΡΡΠΈΠ»ΠΈΡΠ°
|
| 156 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
+
|
| 158 |
+
def test_render_tables_groups_ddl_and_samples():
|
| 159 |
+
tables = [
|
| 160 |
+
TableSchema(
|
| 161 |
+
name="x",
|
| 162 |
+
columns=[ColumnSchema(name="id", type="INT")],
|
| 163 |
+
sample_rows=[(1,), (2,)],
|
| 164 |
+
),
|
| 165 |
+
]
|
| 166 |
+
text = render_tables(tables, include_samples=True)
|
| 167 |
+
assert "CREATE TABLE x" in text
|
| 168 |
+
assert "ΠΡΠΈΠΌΠ΅ΡΡ ΡΡΡΠΎΠΊ" in text
|
| 169 |
+
assert "(1," in text and "(2," in text
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def test_render_tables_no_samples():
|
| 173 |
+
tables = [
|
| 174 |
+
TableSchema(
|
| 175 |
+
name="x",
|
| 176 |
+
columns=[ColumnSchema(name="id", type="INT")],
|
| 177 |
+
sample_rows=[(1,)],
|
| 178 |
+
),
|
| 179 |
+
]
|
| 180 |
+
text = render_tables(tables, include_samples=False)
|
| 181 |
+
assert "CREATE TABLE x" in text
|
| 182 |
+
assert "ΠΡΠΈΠΌΠ΅ΡΡ ΡΡΡΠΎΠΊ" not in text
|
tests/test_vocabulary.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Π’Π΅ΡΡΡ Π½Π° BusinessVocabulary (ΡΠ°Π·Π΄Π΅Π» 3.6 ΠΏΠΎΡΡΠ½ΠΈΡΠ΅Π»ΡΠ½ΠΎΠΉ Π·Π°ΠΏΠΈΡΠΊΠΈ)."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
# ΠΠ°Π³ΡΡΠ·ΠΊΠ°
|
| 12 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 13 |
+
|
| 14 |
+
def test_empty_vocabulary_is_falsy():
|
| 15 |
+
vocab = BusinessVocabulary.empty()
|
| 16 |
+
assert not vocab
|
| 17 |
+
assert vocab.render_system_context() == ""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_from_dict_basic():
|
| 21 |
+
vocab = BusinessVocabulary.from_dict(
|
| 22 |
+
{
|
| 23 |
+
"company": "ΠΠΠ Π’Π΅ΡΡ",
|
| 24 |
+
"terms": {"Π²ΡΡΡΡΠΊΠ°": "SUM(amount)"},
|
| 25 |
+
"filters": {"paid": "status='paid'"},
|
| 26 |
+
"notes": ["ΠΠ΅ΡΠΈΠΎΠ΄: 2025-01-01β2026-04-30"],
|
| 27 |
+
}
|
| 28 |
+
)
|
| 29 |
+
assert bool(vocab)
|
| 30 |
+
assert vocab.company == "ΠΠΠ Π’Π΅ΡΡ"
|
| 31 |
+
assert vocab.terms == {"Π²ΡΡΡΡΠΊΠ°": "SUM(amount)"}
|
| 32 |
+
assert vocab.filters == {"paid": "status='paid'"}
|
| 33 |
+
assert vocab.notes == ["ΠΠ΅ΡΠΈΠΎΠ΄: 2025-01-01β2026-04-30"]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_from_yaml_roundtrip(tmp_path: Path):
|
| 37 |
+
original = BusinessVocabulary(
|
| 38 |
+
company="ΠΠΠ Π ΠΎΠΌΠ°ΡΠΊΠ°",
|
| 39 |
+
terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(orders.amount)"},
|
| 40 |
+
filters={"paid_only": "orders.status='paid'"},
|
| 41 |
+
notes=["amount ΡΠΎΠ»ΡΠΊΠΎ Π² orders"],
|
| 42 |
+
)
|
| 43 |
+
path = tmp_path / "vocab.yaml"
|
| 44 |
+
original.save_yaml(path)
|
| 45 |
+
|
| 46 |
+
loaded = BusinessVocabulary.from_yaml(path)
|
| 47 |
+
assert loaded.company == original.company
|
| 48 |
+
assert loaded.terms == original.terms
|
| 49 |
+
assert loaded.filters == original.filters
|
| 50 |
+
assert loaded.notes == original.notes
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_from_yaml_missing_file(tmp_path: Path):
|
| 54 |
+
with pytest.raises(FileNotFoundError):
|
| 55 |
+
BusinessVocabulary.from_yaml(tmp_path / "does_not_exist.yaml")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 59 |
+
# enrich_prompt β ΠΎΠ±ΡΠ°ΡΠ½Π°Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΠΎΡΡΡ ΡΠΎ ΡΡΠ°ΡΡΠΌ Streamlit-ΠΊΠΎΠ΄ΠΎΠΌ
|
| 60 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
|
| 62 |
+
def test_enrich_prompt_pass_through_when_empty():
|
| 63 |
+
vocab = BusinessVocabulary.empty()
|
| 64 |
+
assert vocab.enrich_prompt("ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ°?") == "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ°?"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_enrich_prompt_adds_term_definition():
|
| 68 |
+
vocab = BusinessVocabulary(
|
| 69 |
+
terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(amount) WHERE status='paid'"},
|
| 70 |
+
)
|
| 71 |
+
enriched = vocab.enrich_prompt("ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° ΡΠ½Π²Π°ΡΡ?")
|
| 72 |
+
assert "Π²ΡΡΡΡΠΊΠ°" in enriched
|
| 73 |
+
assert "SUM(amount)" in enriched
|
| 74 |
+
assert "ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ° Π·Π° ΡΠ½Π²Π°ΡΡ?" in enriched
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_enrich_prompt_case_insensitive_match():
|
| 78 |
+
vocab = BusinessVocabulary(terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(amount)"})
|
| 79 |
+
enriched = vocab.enrich_prompt("ΠΠ«Π Π£Π§ΠΠ Π² ΡΡΠΎΠΌ ΠΌΠ΅ΡΡΡΠ΅?")
|
| 80 |
+
assert "SUM(amount)" in enriched
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def test_enrich_prompt_ignores_unrelated_terms():
|
| 84 |
+
vocab = BusinessVocabulary(
|
| 85 |
+
terms={
|
| 86 |
+
"Π²ΡΡΡΡΠΊΠ°": "SUM(amount)",
|
| 87 |
+
"ΡΠΎΠΏ ΠΊΠ»ΠΈΠ΅Π½ΡΠΎΠ²": "ORDER BY SUM(amount) DESC",
|
| 88 |
+
},
|
| 89 |
+
)
|
| 90 |
+
enriched = vocab.enrich_prompt("ΠΠ°ΠΊΠ°Ρ Π²ΡΡΡΡΠΊΠ°?")
|
| 91 |
+
# Π Π΅Π»Π΅Π²Π°Π½ΡΠ½ΡΠΉ ΡΠ΅ΡΠΌΠΈΠ½ ΠΏΠΎΠ΄ΠΌΠ΅ΡΠ°Π½, ΠΏΠΎΡΡΠΎΡΠΎΠ½Π½ΠΈΠΉ β Π½Π΅Ρ
|
| 92 |
+
assert "SUM(amount)" in enriched
|
| 93 |
+
assert "ORDER BY SUM(amount) DESC" not in enriched
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 97 |
+
# render_system_context β ΡΠΎ, ΡΡΠΎ ΠΏΠΎΠ΄ΠΌΠ΅ΡΠΈΠ²Π°Π΅ΡΡΡ Π² system-ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΠ΅
|
| 98 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 99 |
+
|
| 100 |
+
def test_render_system_context_empty():
|
| 101 |
+
assert BusinessVocabulary.empty().render_system_context() == ""
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_render_system_context_with_company():
|
| 105 |
+
vocab = BusinessVocabulary(
|
| 106 |
+
company="ΠΠΠ Π ΠΎΠΌΠ°ΡΠΊΠ°",
|
| 107 |
+
terms={"Π²ΡΡΡΡΠΊΠ°": "SUM(amount)"},
|
| 108 |
+
)
|
| 109 |
+
ctx = vocab.render_system_context()
|
| 110 |
+
assert "ΠΠΠ Π ΠΎΠΌΠ°ΡΠΊΠ°" in ctx
|
| 111 |
+
assert "Π²ΡΡΡΡΠΊΠ°" in ctx
|
| 112 |
+
assert "SUM(amount)" in ctx
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_render_system_context_with_filters_and_notes():
|
| 116 |
+
vocab = BusinessVocabulary(
|
| 117 |
+
filters={"ΡΠΎΠ»ΡΠΊΠΎ_ΠΎΠΏΠ»Π°ΡΠ΅Π½Π½ΡΠ΅": "orders.status='paid'"},
|
| 118 |
+
notes=["amount ΡΠΎΠ»ΡΠΊΠΎ Π² orders"],
|
| 119 |
+
)
|
| 120 |
+
ctx = vocab.render_system_context()
|
| 121 |
+
assert "ΡΠΎΠ»ΡΠΊΠΎ_ΠΎΠΏΠ»Π°ΡΠ΅Π½Π½ΡΠ΅" in ctx
|
| 122 |
+
assert "orders.status='paid'" in ctx
|
| 123 |
+
assert "amount ΡΠΎΠ»ΡΠΊΠΎ Π² orders" in ctx
|