Tyycha commited on
Commit
cc2ed2f
Β·
1 Parent(s): b167476
.gitignore CHANGED
@@ -68,3 +68,7 @@ htmlcov/
68
  # Outputs
69
  outputs/
70
  results/
 
 
 
 
 
68
  # Outputs
69
  outputs/
70
  results/
71
+
72
+ # Π”ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Ρ‹ Π’ΠšΠ 
73
+ *.docx
74
+ *.pptx
Dockerfile CHANGED
@@ -1,24 +1,43 @@
 
 
 
 
 
 
1
  FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
  && rm -rf /var/lib/apt/lists/*
10
 
 
 
 
 
11
  COPY requirements.txt ./
12
- RUN pip3 install -r requirements.txt
13
 
14
  COPY . .
15
 
16
- ENV LORA_ADAPTER_PATH=Tyycha/qwen-coder-pauq-lora
17
- ENV BASE_MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct
18
- ENV DEVICE=cpu
 
 
 
 
 
 
19
 
20
  EXPOSE 7860
21
 
22
- HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
 
 
 
23
 
24
- ENTRYPOINT ["streamlit", "run", "streamlit_app.py", "--server.port=7860", "--server.address=0.0.0.0"]
 
1
+ # ΠžΠ±Ρ€Π°Π· для HuggingFace Spaces (Docker SDK).
2
+ # ЗапускаСт ΠΎΠ±Π° процСсса Ru2SQL Π² ΠΎΠ΄Π½ΠΎΠΌ ΠΊΠΎΠ½Ρ‚Π΅ΠΉΠ½Π΅Ρ€Π΅:
3
+ # - FastAPI Π½Π° 127.0.0.1:8000 (Π²Π½ΡƒΡ‚Ρ€Π΅Π½Π½ΠΈΠΉ)
4
+ # - Streamlit Π½Π° 0.0.0.0:7860 (внСшний, HF Spaces Ρ‡ΠΈΡ‚Π°Π΅Ρ‚ с этого ΠΏΠΎΡ€Ρ‚Π°)
5
+ # ΠžΡ€ΠΊΠ΅ΡΡ‚Ρ€ΠΈΡ€ΡƒΠ΅Ρ‚ ΠΈΡ… scripts/run_app.py.
6
+
7
  FROM python:3.12-slim
8
 
9
  WORKDIR /app
10
 
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ build-essential \
13
+ curl \
14
+ git \
15
  && rm -rf /var/lib/apt/lists/*
16
 
17
+ # Π‘Π½Π°Ρ‡Π°Π»Π° ставим torch с ΠΏΠΎΠ΄Π΄Π΅Ρ€ΠΆΠΊΠΎΠΉ CUDA 12.1 β€” это Π½ΡƒΠΆΠ½ΠΎ для T4 Small Π½Π° HF.
18
+ # На CPU-ΠΎΠΊΡ€ΡƒΠΆΠ΅Π½ΠΈΠΈ этот ΠΆΠ΅ torch автоматичСски Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ Ρ‡Π΅Ρ€Π΅Π· CPU-бэкСнд.
19
+ RUN pip3 install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
20
+
21
  COPY requirements.txt ./
22
+ RUN pip3 install --no-cache-dir -r requirements.txt
23
 
24
  COPY . .
25
 
26
+ # Π”Π΅Ρ„ΠΎΠ»Ρ‚Ρ‹ для HF Spaces. ΠŸΠ΅Ρ€Π΅ΠΏΠΈΡΡ‹Π²Π°ΡŽΡ‚ΡΡ Variables/Secrets Π² настройках Space.
27
+ ENV LORA_ADAPTER_PATH=Tyycha/qwen-coder-pauq-lora \
28
+ BASE_MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct \
29
+ DEVICE=cuda \
30
+ API_HOST=127.0.0.1 \
31
+ API_PORT=8000 \
32
+ STREAMLIT_HOST=0.0.0.0 \
33
+ STREAMLIT_PORT=7860 \
34
+ RU2SQL_API_URL=http://127.0.0.1:8000
35
 
36
  EXPOSE 7860
37
 
38
+ # start-period=600s β€” Ρƒ ΠΏΠ΅Ρ€Π²ΠΎΠ³ΠΎ запуска Π΅ΡΡ‚ΡŒ Π΄ΠΎ 10 ΠΌΠΈΠ½ΡƒΡ‚ Π½Π° скачиваниС ΠΌΠΎΠ΄Π΅Π»ΠΈ
39
+ # с HuggingFace. Π”Π°Π»ΡŒΡˆΠ΅ healthcheck ΠΎΠΏΡ€Π°ΡˆΠΈΠ²Π°Π΅Ρ‚ Streamlit ΠΊΠ°ΠΆΠ΄Ρ‹Π΅ 30 сСкунд.
40
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=600s --retries=3 \
41
+ CMD curl --fail http://localhost:7860/_stcore/health || exit 1
42
 
43
+ ENTRYPOINT ["python", "scripts/run_app.py"]
README.md CHANGED
@@ -12,23 +12,43 @@ pinned: false
12
  ГСнСративная модСль для прСобразования вопросов Π½Π° русском языкС Π² SQL-запросы.
13
  ΠŸΡ€Π°ΠΊΡ‚ΠΈΡ‡Π΅ΡΠΊΠ°Ρ Ρ‡Π°ΡΡ‚ΡŒ Π’ΠšΠ , Π½Π°ΠΏΡ€Π°Π²Π»Π΅Π½ΠΈΠ΅ Β«ΠŸΡ€ΠΎΠ³Ρ€Π°ΠΌΠΌΠ½Π°Ρ инТСнСрия», 4 курс.
14
 
15
- **Π‘Ρ‚Π΅ΠΊ:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, sqlglot.
16
  **Основная модСль:** Qwen2.5-Coder-3B-Instruct, дообучСнная ΠΌΠ΅Ρ‚ΠΎΠ΄ΠΎΠΌ QLoRA Π½Π° датасСтС PAUQ.
17
  **Π‘Ρ€Π°Π²Π½Π΅Π½ΠΈΠ΅:** ruT5-base baseline + GigaChat API.
18
 
19
- Π‘ΠΌ. `plan_VKR_text2sql_ru.md` для ΠΏΠΎΠ»Π½ΠΎΠ³ΠΎ ΠΏΠ»Π°Π½Π° Ρ€Π°Π±ΠΎΡ‚ Π½Π° мСсяц.
20
 
21
  ---
22
 
23
- ## Быстрый старт (Π½Π° дСсктопС)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  ### 1. Установка
26
 
27
  ```bash
28
- # Установи uv (https://docs.astral.sh/uv/) Ссли Π΅Ρ‰Ρ‘ Π½Π΅Ρ‚
29
  pip install uv
30
-
31
- # ΠšΠ»ΠΎΠ½ΠΈΡ€ΡƒΠΉ Ρ€Π΅ΠΏΠΎΠ·ΠΈΡ‚ΠΎΡ€ΠΈΠΉ ΠΈ установи зависимости
32
  git clone <Ρ‚Π²ΠΎΠΉ-Ρ€Π΅ΠΏΠΎ> ru2sql
33
  cd ru2sql
34
  uv venv
@@ -44,36 +64,97 @@ copy .env.example .env # Windows
44
  # cp .env.example .env # Linux/Mac
45
  ```
46
 
47
- ΠžΡ‚ΠΊΡ€ΠΎΠΉ `.env` ΠΈ Π·Π°ΠΏΠΎΠ»Π½ΠΈ ΠΊΠ»ΡŽΡ‡ΠΈ (ΠΌΠΈΠ½ΠΈΠΌΡƒΠΌ `GIGACHAT_API_KEY` для baseline-сравнСния, ΠΎΡΡ‚Π°Π»ΡŒΠ½ΠΎΠ΅ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ).
 
 
 
 
 
 
 
 
48
 
49
- ### 3. Π‘ΠΊΠ°Ρ‡Π°ΠΉ PAUQ
50
 
51
  ```bash
52
- git clone https://github.com/ai-forever/pauq.git data/pauq_repo
53
- # Π—Π°Ρ‚Π΅ΠΌ Ρ€Π°Π·Π»ΠΎΠΆΠΈ train.json/dev.json/test.json Π² data/pauq/
54
- # ΠΈ SQLite-Ρ„Π°ΠΉΠ»Ρ‹ Π² data/databases/{db_id}/{db_id}.sqlite
55
  ```
56
 
57
- ### 4. ВСсты
58
 
 
 
 
59
  ```bash
60
- pytest -v
61
  ```
62
 
63
- ВСсты для ΠΌΠΎΠ΄ΡƒΠ»Π΅ΠΉ `prompt`, `postprocess`, `metrics`, `schema` Π΄ΠΎΠ»ΠΆΠ½Ρ‹ ΠΏΡ€ΠΎΡ…ΠΎΠ΄ΠΈΡ‚ΡŒ
64
- Π±Π΅Π· скачивания ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ датасСта.
 
 
65
 
66
- ### 5. Запуск API
67
 
 
 
 
68
  ```bash
69
  uvicorn src.api.main:app --reload
70
  # Swagger UI: http://127.0.0.1:8000/docs
71
  ```
72
 
73
- ΠŸΡ€ΠΈ ΠΏΠ΅Ρ€Π²ΠΎΠΌ запускС модСль Qwen2.5-Coder-3B (~6 GB) скачаСтся ΠΈΠ· HuggingFace Hub.
74
- На CPU инфСрСнс Π·Π°Π½ΠΈΠΌΠ°Π΅Ρ‚ 15–30 сСкунд Π½Π° запрос β€” это ΠΎΠΆΠΈΠ΄Π°Π΅ΠΌΠΎ.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- ### 6. Запрос ΠΊ API
77
 
78
  ```bash
79
  curl -X POST http://127.0.0.1:8000/generate-sql \
@@ -90,7 +171,7 @@ curl -X POST http://127.0.0.1:8000/generate-sql \
90
 
91
  Π¨Π°Π³ΠΈ:
92
  1. ΠžΡ‚ΠΊΡ€ΠΎΠΉ `notebooks/kaggle_train_qwen_qlora.ipynb` Π½Π° kaggle.com.
93
- 2. Π’ Settings Π²Ρ‹Π±Π΅Ρ€ΠΈ Accelerator: GPU T4 x1 (ΠΈΠ»ΠΈ x2 для скорости).
94
  3. Add-ons β†’ Secrets β†’ добавь `HF_TOKEN` ΠΈ `WANDB_API_KEY`.
95
  4. Запусти всС ячСйки. Π’Ρ€Π΅Π½ΠΈΡ€ΠΎΠ²ΠΊΠ° ~4–6 часов.
96
  5. По Π·Π°Π²Π΅Ρ€ΡˆΠ΅Π½ΠΈΠΈ Π°Π΄Π°ΠΏΡ‚Π΅Ρ€ ΠΏΡƒΡˆΠΈΡ‚ΡΡ Π½Π° Ρ‚Π²ΠΎΠΉ ΠΏΡ€ΠΈΠ²Π°Ρ‚Π½Ρ‹ΠΉ HF-Ρ€Π΅ΠΏΠΎ.
@@ -109,32 +190,47 @@ curl -X POST http://127.0.0.1:8000/generate-sql \
109
 
110
  ```
111
  ru2sql/
112
- β”œβ”€β”€ pyproject.toml # зависимости (uv)
113
  β”œβ”€β”€ .env.example # шаблон ΠΊΠΎΠ½Ρ„ΠΈΠ³ΡƒΡ€Π°Ρ†ΠΈΠΈ
114
- β”œβ”€β”€ plan_VKR_text2sql_ru.md # ΠΏΠ»Π°Π½ Ρ€Π°Π±ΠΎΡ‚ Π½Π° мСсяц
115
  β”œβ”€β”€ notebooks/
116
  β”‚ └── kaggle_train_qwen_qlora.ipynb
 
 
 
 
 
117
  β”œβ”€β”€ src/
118
  β”‚ β”œβ”€β”€ config.py # настройки Ρ‡Π΅Ρ€Π΅Π· pydantic-settings
119
  β”‚ β”œβ”€β”€ data/
120
  β”‚ β”‚ β”œβ”€β”€ loader.py # Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ PAUQ JSON
121
- β”‚ β”‚ β”œβ”€β”€ schema.py # SchemaRetriever (DDL ΠΈΠ· SQLite)
 
122
  β”‚ β”‚ └── prompt.py # PromptBuilder + chat-template
 
 
 
 
 
123
  β”‚ β”œβ”€β”€ models/
124
  β”‚ β”‚ β”œβ”€β”€ inference.py # InferenceEngine (модСль + LoRA)
125
- β”‚ β”‚ └── postprocess.py # очистка SQL + sqlglot валидация
126
  β”‚ β”œβ”€β”€ evaluation/
127
- β”‚ β”‚ β”œβ”€β”€ metrics.py # Exact Match + Execution Accuracy
128
- β”‚ β”‚ └── evaluate.py # CLI для ΠΏΡ€ΠΎΠ³ΠΎΠ½Π° Π½Π° split'Π΅
129
  β”‚ └── api/
130
- β”‚ β”œβ”€β”€ main.py # FastAPI app
131
  β”‚ β”œβ”€β”€ schemas.py # Pydantic-ΠΌΠΎΠ΄Π΅Π»ΠΈ
132
  β”‚ └── dependencies.py # lifespan + DI
133
- └── tests/
 
134
  β”œβ”€β”€ test_prompt.py
135
  β”œβ”€β”€ test_postprocess.py
136
  β”œβ”€β”€ test_metrics.py
137
- └── test_schema.py
 
 
 
138
  ```
139
 
140
  ---
@@ -153,13 +249,13 @@ python -m src.evaluation.evaluate --split dev --limit 50
153
 
154
  ---
155
 
156
- ## ΠœΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ (ΠΏΠ»Π°Π½ΠΈΡ€ΡƒΠ΅ΠΌΡ‹Π΅)
157
 
158
  | МодСль | EM | Execution Accuracy |
159
  |---|---|---|
160
- | ruT5-base (baseline) | 25–35% | 30–40% |
161
- | **Qwen2.5-Coder-3B + QLoRA** | **50–60%** | **55–70%** |
162
- | GigaChat API (zero-shot) | 55–70% | 65–80% |
163
 
164
  ---
165
 
 
12
  ГСнСративная модСль для прСобразования вопросов Π½Π° русском языкС Π² SQL-запросы.
13
  ΠŸΡ€Π°ΠΊΡ‚ΠΈΡ‡Π΅ΡΠΊΠ°Ρ Ρ‡Π°ΡΡ‚ΡŒ Π’ΠšΠ , Π½Π°ΠΏΡ€Π°Π²Π»Π΅Π½ΠΈΠ΅ Β«ΠŸΡ€ΠΎΠ³Ρ€Π°ΠΌΠΌΠ½Π°Ρ инТСнСрия», 4 курс.
14
 
15
+ **Π‘Ρ‚Π΅ΠΊ:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, Streamlit, sqlglot.
16
  **Основная модСль:** Qwen2.5-Coder-3B-Instruct, дообучСнная ΠΌΠ΅Ρ‚ΠΎΠ΄ΠΎΠΌ QLoRA Π½Π° датасСтС PAUQ.
17
  **Π‘Ρ€Π°Π²Π½Π΅Π½ΠΈΠ΅:** ruT5-base baseline + GigaChat API.
18
 
19
+ Π‘ΠΌ. `plan_VKR_text2sql_ru.md` для ΠΏΠΎΠ»Π½ΠΎΠ³ΠΎ ΠΏΠ»Π°Π½Π° Ρ€Π°Π±ΠΎΡ‚.
20
 
21
  ---
22
 
23
+ ## АрхитСктура
24
+
25
+ ```
26
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” HTTP β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
27
+ β”‚ Streamlit-ΠΊΠ»ΠΈΠ΅Π½Ρ‚ β”‚ ─────────► β”‚ FastAPI REST API β”‚
28
+ β”‚ (ΠΏΠΎΡ€Ρ‚ 8501) β”‚ β”‚ (ΠΏΠΎΡ€Ρ‚ 8000) β”‚
29
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
30
+ β”‚
31
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
32
+ β–Ό β–Ό β–Ό
33
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
34
+ β”‚ InferenceEngine β”‚ β”‚ SchemaProvi- β”‚ β”‚ BusinessVocab- β”‚
35
+ β”‚ Qwen + LoRA β”‚ β”‚ der + Sql- β”‚ β”‚ ulary (YAML) β”‚
36
+ β”‚ β”‚ β”‚ Executor β”‚ β”‚ β”‚
37
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
38
+ ```
39
+
40
+ Streamlit-интСрфСйс Π½Π΅ Π²Ρ‹Π·Ρ‹Π²Π°Π΅Ρ‚ модСль Π½Π°ΠΏΡ€ΡΠΌΡƒΡŽ β€” ΠΎΠ½ обращаСтся ΠΊ REST API
41
+ Ρ‡Π΅Ρ€Π΅Π· `httpx`. Π­Ρ‚ΠΎ позволяСт Π·Π°ΠΏΡƒΡΠΊΠ°Ρ‚ΡŒ UI ΠΈ инфСрСнс Π½Π° Ρ€Π°Π·Π½Ρ‹Ρ… ΠΌΠ°ΡˆΠΈΠ½Π°Ρ…,
42
+ Π° Ρ‚Π°ΠΊΠΆΠ΅ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π°Ρ‚ΡŒ ΠΊ API Π»ΡŽΠ±Ρ‹Ρ… сторонних ΠΊΠ»ΠΈΠ΅Π½Ρ‚ΠΎΠ².
43
+
44
+ ---
45
+
46
+ ## Быстрый старт
47
 
48
  ### 1. Установка
49
 
50
  ```bash
 
51
  pip install uv
 
 
52
  git clone <Ρ‚Π²ΠΎΠΉ-Ρ€Π΅ΠΏΠΎ> ru2sql
53
  cd ru2sql
54
  uv venv
 
64
  # cp .env.example .env # Linux/Mac
65
  ```
66
 
67
+ Π—Π°ΠΏΠΎΠ»Π½ΠΈ Π² `.env`:
68
+ - `BASE_MODEL_NAME` (ΠΏΠΎ ΡƒΠΌΠΎΠ»Ρ‡Π°Π½ΠΈΡŽ Qwen/Qwen2.5-Coder-3B-Instruct),
69
+ - `LORA_ADAPTER_PATH` (локальная папка или HF-repo),
70
+ - ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ β€” API-ΠΊΠ»ΡŽΡ‡ GigaChat для baseline-сравнСния.
71
+
72
+ Если HuggingFace нСдоступСн ΠΈΠ· вашСй сСти, добавь:
73
+ ```
74
+ HF_ENDPOINT=https://hf-mirror.com
75
+ ```
76
 
77
+ ### 3. ВСсты
78
 
79
  ```bash
80
+ pytest -v
 
 
81
  ```
82
 
83
+ ОТидаСмо: 80+ Π·Π΅Π»Ρ‘Π½Ρ‹Ρ… тСстов.
84
 
85
+ ### 4. Smoke-ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ°
86
+
87
+ Быстрая (5 сСк, Π±Π΅Π· ΠΌΠΎΠ΄Π΅Π»ΠΈ):
88
  ```bash
89
+ python scripts/smoke_local.py
90
  ```
91
 
92
+ Полная (с Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠΎΠΉ Qwen, ~5 ΠΌΠΈΠ½ΡƒΡ‚ Π½Π° CPU):
93
+ ```bash
94
+ python scripts/smoke_local.py --with-model
95
+ ```
96
 
97
+ ### 5. Запуск прилоТСния
98
 
99
+ НуТны **Π΄Π²Π° процСсса** β€” API ΠΈ UI:
100
+
101
+ **Окно 1** β€” REST API:
102
  ```bash
103
  uvicorn src.api.main:app --reload
104
  # Swagger UI: http://127.0.0.1:8000/docs
105
  ```
106
 
107
+ **Окно 2** β€” Streamlit-интСрфСйс:
108
+ ```bash
109
+ streamlit run streamlit_app.py
110
+ # UI: http://127.0.0.1:8501
111
+ ```
112
+
113
+ ΠŸΡ€ΠΈ ΠΏΠ΅Ρ€Π²ΠΎΠΌ запускС модСль Qwen2.5-Coder-3B (~6 GB) скачиваСтся ΠΈΠ· HuggingFace
114
+ Hub. На CPU инфСрСнс ΠΎΠ΄Π½ΠΎΠ³ΠΎ запроса Π·Π°Π½ΠΈΠΌΠ°Π΅Ρ‚ 15–30 сСкунд β€” это ΠΎΠΆΠΈΠ΄Π°Π΅ΠΌΠΎ.
115
+
116
+ АдрСс API ΠΌΠΎΠΆΠ½ΠΎ ΠΏΠ΅Ρ€Π΅ΠΎΠΏΡ€Π΅Π΄Π΅Π»ΠΈΡ‚ΡŒ ΠΏΠ΅Ρ€Π΅ΠΌΠ΅Π½Π½ΠΎΠΉ окруТСния:
117
+ ```bash
118
+ set RU2SQL_API_URL=http://192.168.1.10:8000 # Windows
119
+ # export RU2SQL_API_URL=http://192.168.1.10:8000 # Linux/Mac
120
+ ```
121
+
122
+ ---
123
+
124
+ ## REST API
125
+
126
+ ### Π‘Π°Π·ΠΎΠ²Ρ‹Π΅ эндпоинты
127
+
128
+ | ΠœΠ΅Ρ‚ΠΎΠ΄ | ΠŸΡƒΡ‚ΡŒ | НазначСниС |
129
+ |---|---|---|
130
+ | GET | `/health` | статус сСрвиса ΠΈ Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ |
131
+ | GET | `/databases` | список Π‘Π” ΠΈΠ· data/databases (PAUQ-структура) |
132
+ | POST | `/generate-sql` | гСнСрация SQL ΠΏΠΎ db_id ΠΈΠ· PAUQ |
133
+ | POST | `/schema` | схСма ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” ΠΏΠΎ connection string |
134
+ | POST | `/query` | ΠΏΠΎΠ»Π½Ρ‹ΠΉ pipeline для ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” |
135
+
136
+ ### ΠŸΡ€ΠΈΠΌΠ΅Ρ€: запрос ΠΊ ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π”
137
+
138
+ ```bash
139
+ curl -X POST http://127.0.0.1:8000/query \
140
+ -H "Content-Type: application/json" \
141
+ -d '{
142
+ "question": "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?",
143
+ "connection_string": "sqlite:///data/demo/sales.sqlite",
144
+ "execute": true,
145
+ "vocabulary": {
146
+ "company": "Π”Π΅ΠΌΠΎ-ΠΌΠ°Π³Π°Π·ΠΈΠ½",
147
+ "terms": {"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(orders.amount) WHERE status='paid'"}
148
+ }
149
+ }'
150
+ ```
151
+
152
+ ΠžΡ‚Π²Π΅Ρ‚ содСрТит `sql`, `raw_output`, `is_valid_sql`, `gen_time_seconds` ΠΈ
153
+ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ `execution` с Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚Π°ΠΌΠΈ выполнСния. ΠŸΠ΅Ρ€Π΅Π΄ исполнСниСм
154
+ SQL ΠΏΡ€ΠΎΡ…ΠΎΠ΄ΠΈΡ‚ AST-ΡƒΡ€ΠΎΠ²Π½Π΅Π²ΡƒΡŽ ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΡƒ (см. `is_select_only` Π²
155
+ `src/models/postprocess.py`) β€” DDL ΠΈ DML Π½Π° Π‘Π” Ρ‡Π΅Ρ€Π΅Π· API Π½Π΅Π²ΠΎΠ·ΠΌΠΎΠΆΠ½Ρ‹.
156
 
157
+ ### ΠŸΡ€ΠΈΠΌΠ΅Ρ€: PAUQ-Ρ€Π΅ΠΆΠΈΠΌ (старый эндпоинт)
158
 
159
  ```bash
160
  curl -X POST http://127.0.0.1:8000/generate-sql \
 
171
 
172
  Π¨Π°Π³ΠΈ:
173
  1. ΠžΡ‚ΠΊΡ€ΠΎΠΉ `notebooks/kaggle_train_qwen_qlora.ipynb` Π½Π° kaggle.com.
174
+ 2. Π’ Settings Π²Ρ‹Π±Π΅Ρ€ΠΈ Accelerator: GPU T4 x1.
175
  3. Add-ons β†’ Secrets β†’ добавь `HF_TOKEN` ΠΈ `WANDB_API_KEY`.
176
  4. Запусти всС ячСйки. Π’Ρ€Π΅Π½ΠΈΡ€ΠΎΠ²ΠΊΠ° ~4–6 часов.
177
  5. По Π·Π°Π²Π΅Ρ€ΡˆΠ΅Π½ΠΈΠΈ Π°Π΄Π°ΠΏΡ‚Π΅Ρ€ ΠΏΡƒΡˆΠΈΡ‚ΡΡ Π½Π° Ρ‚Π²ΠΎΠΉ ΠΏΡ€ΠΈΠ²Π°Ρ‚Π½Ρ‹ΠΉ HF-Ρ€Π΅ΠΏΠΎ.
 
190
 
191
  ```
192
  ru2sql/
193
+ β”œβ”€β”€ pyproject.toml # зависимости
194
  β”œβ”€β”€ .env.example # шаблон ΠΊΠΎΠ½Ρ„ΠΈΠ³ΡƒΡ€Π°Ρ†ΠΈΠΈ
195
+ β”œβ”€β”€ plan_VKR_text2sql_ru.md # ΠΏΠ»Π°Π½ Ρ€Π°Π±ΠΎΡ‚
196
  β”œβ”€β”€ notebooks/
197
  β”‚ └── kaggle_train_qwen_qlora.ipynb
198
+ β”œβ”€β”€ scripts/
199
+ β”‚ └── smoke_local.py # локальная ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° работоспособности
200
+ β”œβ”€β”€ configs/
201
+ β”‚ β”œβ”€β”€ example_vocabulary.yaml
202
+ β”‚ └── sales_vocabulary.yaml
203
  β”œβ”€β”€ src/
204
  β”‚ β”œβ”€β”€ config.py # настройки Ρ‡Π΅Ρ€Π΅Π· pydantic-settings
205
  β”‚ β”œβ”€β”€ data/
206
  β”‚ β”‚ β”œβ”€β”€ loader.py # Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ PAUQ JSON
207
+ β”‚ β”‚ β”œβ”€β”€ schema_provider.py # SchemaProvider β€” Π΅Π΄ΠΈΠ½Ρ‹ΠΉ интСрфСйс
208
+ β”‚ β”‚ β”œβ”€β”€ schema.py # SchemaRetriever (фасад для PAUQ)
209
  β”‚ β”‚ └── prompt.py # PromptBuilder + chat-template
210
+ β”‚ β”œβ”€β”€ db/
211
+ β”‚ β”‚ β”œβ”€β”€ connector.py # DbConnector β€” Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ схСм
212
+ β”‚ β”‚ └── executor.py # SqlExecutor с read-only
213
+ β”‚ β”œβ”€β”€ business/
214
+ β”‚ β”‚ └── vocabulary.py # BusinessVocabulary (YAML-ΠΊΠΎΠ½Ρ„ΠΈΠ³)
215
  β”‚ β”œβ”€β”€ models/
216
  β”‚ β”‚ β”œβ”€β”€ inference.py # InferenceEngine (модСль + LoRA)
217
+ β”‚ β”‚ └── postprocess.py # очистка SQL + guardrail
218
  β”‚ β”œβ”€β”€ evaluation/
219
+ β”‚ β”‚ β”œβ”€β”€ metrics.py # EM + Execution Accuracy
220
+ β”‚ β”‚ └── evaluate.py # CLI для ΠΏΡ€ΠΎΠ³ΠΎΠ½Π° Π½Π° split
221
  β”‚ └── api/
222
+ β”‚ β”œβ”€β”€ main.py # FastAPI app (5 эндпоинтов)
223
  β”‚ β”œβ”€β”€ schemas.py # Pydantic-ΠΌΠΎΠ΄Π΅Π»ΠΈ
224
  β”‚ └── dependencies.py # lifespan + DI
225
+ β”œβ”€οΏ½οΏ½ streamlit_app.py # UI (HTTPX-ΠΊΠ»ΠΈΠ΅Π½Ρ‚ ΠΊ API)
226
+ └── tests/ # 80+ тСстов
227
  β”œβ”€β”€ test_prompt.py
228
  β”œβ”€β”€ test_postprocess.py
229
  β”œβ”€β”€ test_metrics.py
230
+ β”œβ”€β”€ test_schema.py
231
+ β”œβ”€β”€ test_schema_provider.py
232
+ β”œβ”€β”€ test_vocabulary.py
233
+ └── test_db.py
234
  ```
235
 
236
  ---
 
249
 
250
  ---
251
 
252
+ ## ΠœΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ
253
 
254
  | МодСль | EM | Execution Accuracy |
255
  |---|---|---|
256
+ | ruT5-base (baseline) | 25–35 % | 30–40 % |
257
+ | **Qwen2.5-Coder-3B + QLoRA** | **40,0 %** | **71,9 %** |
258
+ | BRIDGE / RAT-SQL (PAUQ, mono) | 51 / 52 % | 48 / 49 % |
259
 
260
  ---
261
 
adapters/qwen-coder-pauq-lora/README.md CHANGED
@@ -1,199 +1,136 @@
1
  ---
2
- library_name: transformers
3
- tags: []
 
 
 
 
 
 
 
 
 
 
4
  ---
5
 
6
- # Model Card for Model ID
7
-
8
- <!-- Provide a quick summary of what the model is/does. -->
9
-
10
-
11
-
12
- ## Model Details
13
-
14
- ### Model Description
15
-
16
- <!-- Provide a longer summary of what this model is. -->
17
-
18
- This is the model card of a πŸ€— transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
-
20
- - **Developed by:** [More Information Needed]
21
- - **Funded by [optional]:** [More Information Needed]
22
- - **Shared by [optional]:** [More Information Needed]
23
- - **Model type:** [More Information Needed]
24
- - **Language(s) (NLP):** [More Information Needed]
25
- - **License:** [More Information Needed]
26
- - **Finetuned from model [optional]:** [More Information Needed]
27
-
28
- ### Model Sources [optional]
29
-
30
- <!-- Provide the basic links for the model. -->
31
-
32
- - **Repository:** [More Information Needed]
33
- - **Paper [optional]:** [More Information Needed]
34
- - **Demo [optional]:** [More Information Needed]
35
-
36
- ## Uses
37
-
38
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
-
40
- ### Direct Use
41
-
42
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
-
44
- [More Information Needed]
45
-
46
- ### Downstream Use [optional]
47
-
48
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
-
50
- [More Information Needed]
51
-
52
- ### Out-of-Scope Use
53
-
54
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
-
56
- [More Information Needed]
57
-
58
- ## Bias, Risks, and Limitations
59
-
60
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
-
62
- [More Information Needed]
63
-
64
- ### Recommendations
65
-
66
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
-
68
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
-
70
- ## How to Get Started with the Model
71
-
72
- Use the code below to get started with the model.
73
-
74
- [More Information Needed]
75
-
76
- ## Training Details
77
-
78
- ### Training Data
79
-
80
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
-
82
- [More Information Needed]
83
-
84
- ### Training Procedure
85
-
86
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
-
88
- #### Preprocessing [optional]
89
-
90
- [More Information Needed]
91
-
92
-
93
- #### Training Hyperparameters
94
-
95
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
-
97
- #### Speeds, Sizes, Times [optional]
98
-
99
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
-
101
- [More Information Needed]
102
-
103
- ## Evaluation
104
-
105
- <!-- This section describes the evaluation protocols and provides the results. -->
106
-
107
- ### Testing Data, Factors & Metrics
108
-
109
- #### Testing Data
110
-
111
- <!-- This should link to a Dataset Card if possible. -->
112
-
113
- [More Information Needed]
114
-
115
- #### Factors
116
-
117
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
-
119
- [More Information Needed]
120
-
121
- #### Metrics
122
-
123
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
-
125
- [More Information Needed]
126
-
127
- ### Results
128
-
129
- [More Information Needed]
130
-
131
- #### Summary
132
-
133
-
134
-
135
- ## Model Examination [optional]
136
-
137
- <!-- Relevant interpretability work for the model goes here -->
138
-
139
- [More Information Needed]
140
-
141
- ## Environmental Impact
142
-
143
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
-
145
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
-
147
- - **Hardware Type:** [More Information Needed]
148
- - **Hours used:** [More Information Needed]
149
- - **Cloud Provider:** [More Information Needed]
150
- - **Compute Region:** [More Information Needed]
151
- - **Carbon Emitted:** [More Information Needed]
152
-
153
- ## Technical Specifications [optional]
154
-
155
- ### Model Architecture and Objective
156
-
157
- [More Information Needed]
158
-
159
- ### Compute Infrastructure
160
-
161
- [More Information Needed]
162
-
163
- #### Hardware
164
-
165
- [More Information Needed]
166
-
167
- #### Software
168
-
169
- [More Information Needed]
170
-
171
- ## Citation [optional]
172
-
173
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
-
175
- **BibTeX:**
176
-
177
- [More Information Needed]
178
-
179
- **APA:**
180
-
181
- [More Information Needed]
182
-
183
- ## Glossary [optional]
184
-
185
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
-
187
- [More Information Needed]
188
-
189
- ## More Information [optional]
190
-
191
- [More Information Needed]
192
-
193
- ## Model Card Authors [optional]
194
-
195
- [More Information Needed]
196
-
197
- ## Model Card Contact
198
-
199
- [More Information Needed]
 
1
  ---
2
+ library_name: peft
3
+ base_model: Qwen/Qwen2.5-Coder-3B-Instruct
4
+ language:
5
+ - ru
6
+ tags:
7
+ - text-to-sql
8
+ - qlora
9
+ - russian
10
+ - sql-generation
11
+ license: apache-2.0
12
+ datasets:
13
+ - ai-forever/PAUQ
14
  ---
15
 
16
+ # qwen-coder-pauq-lora
17
+
18
+ LoRA-Π°Π΄Π°ΠΏΡ‚Π΅Ρ€ для ΠΌΠΎΠ΄Π΅Π»ΠΈ **Qwen2.5-Coder-3B-Instruct**, ΠΎΠ±ΡƒΡ‡Π΅Π½Π½Ρ‹ΠΉ ΠΌΠ΅Ρ‚ΠΎΠ΄ΠΎΠΌ
19
+ **QLoRA** Π½Π° датасСтС **PAUQ** для Π·Π°Π΄Π°Ρ‡ΠΈ прСобразования вопросов Π½Π°
20
+ русском языкС Π² SQL-запросы. Π§Π°ΡΡ‚ΡŒ выпускной ΠΊΠ²Π°Π»ΠΈΡ„ΠΈΠΊΠ°Ρ†ΠΈΠΎΠ½Π½ΠΎΠΉ Ρ€Π°Π±ΠΎΡ‚Ρ‹
21
+ ΠΏΠΎ Π½Π°ΠΏΡ€Π°Π²Π»Π΅Π½ΠΈΡŽ 09.03.04 Β«ΠŸΡ€ΠΎΠ³Ρ€Π°ΠΌΠΌΠ½Π°Ρ инТСнСрия», КНИВУ-КАИ.
22
+
23
+ ## ОписаниС
24
+
25
+ АдаптСр ΠΎΠ±ΡƒΡ‡Π°Π΅Ρ‚ модСль ΠΎΡ‚Π²Π΅Ρ‡Π°Ρ‚ΡŒ Π½Π° русскоязычный аналитичСский вопрос
26
+ синтаксичСски ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹ΠΌ SQL-запросом, учитывая схСму ΠΊΠΎΠ½ΠΊΡ€Π΅Ρ‚Π½ΠΎΠΉ Π±Π°Π·Ρ‹
27
+ Π΄Π°Π½Π½Ρ‹Ρ…, ΠΏΠ΅Ρ€Π΅Π΄Π°Π½Π½ΡƒΡŽ Π² систСмном сообщСнии вмСстС с ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π°ΠΌΠΈ строк.
28
+
29
+ Базовая модСль остаётся Π·Π°ΠΌΠΎΡ€ΠΎΠΆΠ΅Π½Π½ΠΎΠΉ, ΠΎΠ±ΡƒΡ‡Π°ΡŽΡ‚ΡΡ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ LoRA-ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹
30
+ Ρ€Π°Π½Π³ΠΎΠΌ 16, Π½Π°Π»ΠΎΠΆΠ΅Π½Π½Ρ‹Π΅ Π½Π° всС ΠΏΡ€ΠΎΠ΅ΠΊΡ†ΠΈΠΎΠ½Π½Ρ‹Π΅ слои attention ΠΈ MLP. Π­Ρ‚ΠΎ
31
+ позволяСт Ρ…Ρ€Π°Π½ΠΈΡ‚ΡŒ ΠΈ Ρ€Π°ΡΠΏΡ€ΠΎΡΡ‚Ρ€Π°Π½ΡΡ‚ΡŒ Π°Π΄Π°ΠΏΡ‚Π΅Ρ€ Ρ€Π°Π·ΠΌΠ΅Ρ€ΠΎΠΌ нСсколько дСсятков
32
+ ΠΌΠ΅Π³Π°Π±Π°ΠΉΡ‚, Π° Π½Π΅ ΠΏΠΎΠ»Π½Ρ‹Π΅ вСса ΠΌΠΎΠ΄Π΅Π»ΠΈ Π² нСсколько Π³ΠΈΠ³Π°Π±Π°ΠΉΡ‚.
33
+
34
+ ## ИспользованиС
35
+
36
+ ```python
37
+ from transformers import AutoModelForCausalLM, AutoTokenizer
38
+ from peft import PeftModel
39
+
40
+ base = "Qwen/Qwen2.5-Coder-3B-Instruct"
41
+ adapter = "Tyycha/qwen-coder-pauq-lora"
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(base)
44
+ model = AutoModelForCausalLM.from_pretrained(base, device_map="auto")
45
+ model = PeftModel.from_pretrained(model, adapter)
46
+ model.eval()
47
+
48
+ messages = [
49
+ {"role": "system", "content": "Π’Ρ‹ β€” ассистСнт, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΏΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΡƒΠ΅Ρ‚ вопросы Π½Π° русском языкС Π² SQL..."},
50
+ {"role": "user", "content": "### Schema:\nCREATE TABLE orders (id INT, amount REAL, status TEXT);\n\n### Question:\nКакая суммарная Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° ΠΏΠΎ ΠΎΠΏΠ»Π°Ρ‡Π΅Π½Π½Ρ‹ΠΌ Π·Π°ΠΊΠ°Π·Π°ΠΌ?\n\n### SQL:\n"},
51
+ ]
52
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
53
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
54
+ output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
55
+ print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
56
+ ```
57
+
58
+ ΠŸΠΎΠ»Π½Ρ‹ΠΉ pipeline (Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΏΡ€ΠΎΠΌΠΏΡ‚Π°, постобработка, исполнСниС SQL)
59
+ Ρ€Π΅Π°Π»ΠΈΠ·ΠΎΠ²Π°Π½ Π² ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π΅ [Ru2SQL](https://github.com/Tyycha/Ru2SQL).
60
+
61
+ ## Π”Π°Π½Π½Ρ‹Π΅
62
+
63
+ - **ДатасСт:** PAUQ (Bakshandaeva et al., 2022) β€” ΠΏΠ΅Ρ€Π²Ρ‹ΠΉ ΠΊΡ€ΡƒΠΏΠ½Ρ‹ΠΉ русскоязычный
64
+ корпус для Π·Π°Π΄Π°Ρ‡ΠΈ Text-to-SQL, построСнный Π½Π° основС Spider.
65
+ - **Π Π°Π·ΠΌΠ΅Ρ€:** ~7 тыс. ΠΏΠ°Ρ€ (вопрос, SQL) Π² train, 1076 Π² dev.
66
+ - **Π‘Ρ‚Ρ€ΡƒΠΊΡ‚ΡƒΡ€Π° ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π°:** Π΄ΠΈΠ°Π»ΠΎΠ³ Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅ chat-template ΠΈΠ· Ρ‚Ρ€Ρ‘Ρ… сообщСний
67
+ (`system` с инструкциСй ΠΈ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½Ρ‹ΠΌ бизнСс-словарём, `user` со
68
+ схСмой ΠΈ вопросом, `assistant` с эталонным SQL).
69
+
70
+ ## Π“ΠΈΠΏΠ΅Ρ€ΠΏΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ обучСния
71
+
72
+ | ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€ | Π—Π½Π°Ρ‡Π΅Π½ΠΈΠ΅ |
73
+ |---|---|
74
+ | Базовая модСль | Qwen/Qwen2.5-Coder-3B-Instruct |
75
+ | ΠœΠ΅Ρ‚ΠΎΠ΄ | QLoRA (NF4, double quantization, compute dtype = bfloat16) |
76
+ | LoRA rank `r` | 16 |
77
+ | LoRA alpha | 32 |
78
+ | LoRA dropout | 0.05 |
79
+ | Target modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
80
+ | Π­ΠΏΠΎΡ…ΠΈ | 2 |
81
+ | Π Π°Π·ΠΌΠ΅Ρ€ Π±Π°Ρ‚Ρ‡Π° | 1 (Π½Π° устройство) |
82
+ | НакоплСниС Π³Ρ€Π°Π΄ΠΈΠ΅Π½Ρ‚ΠΎΠ² | 8 шагов (эффСктивный Π±Π°Ρ‚Ρ‡ 8) |
83
+ | Π‘ΠΊΠΎΡ€ΠΎΡΡ‚ΡŒ обучСния | 2e-4 (cosine, warmup 3 %) |
84
+ | Max sequence length | 1024 |
85
+ | Π’ΠΎΡ‡Π½ΠΎΡΡ‚ΡŒ | bfloat16 |
86
+ | ΠžΠ±ΠΎΡ€ΡƒΠ΄ΠΎΠ²Π°Π½ΠΈΠ΅ | NVIDIA Tesla T4 16 GB (Kaggle) |
87
+
88
+ ## ΠœΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ
89
+
90
+ ΠžΡ†Π΅Π½ΠΊΠ° Π½Π° Π²Π°Π»ΠΈΠ΄Π°Ρ†ΠΈΠΎΠ½Π½ΠΎΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠ΅ PAUQ (1076 ΠΏΡ€ΠΈΠΌΠ΅Ρ€ΠΎΠ², 166 Π±Π°Π· Π΄Π°Π½Π½Ρ‹Ρ…):
91
+
92
+ | ΠœΠ΅Ρ‚Ρ€ΠΈΠΊΠ° | Π—Π½Π°Ρ‡Π΅Π½ΠΈΠ΅ |
93
+ |---|---|
94
+ | Exact Match (EM) | 40.0 % (430 / 1076) |
95
+ | Execution Accuracy (EX) | 71.9 % (772 / 1074) |
96
+
97
+ Для сравнСния, ΠΎΠΏΡƒΠ±Π»ΠΈΠΊΠΎΠ²Π°Π½Π½Ρ‹Π΅ Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚Ρ‹ Π½Π° Ρ‚ΠΎΠΉ ΠΆΠ΅ Π²Ρ‹Π±ΠΎΡ€ΠΊΠ΅:
98
+
99
+ | МодСль | EM | EX |
100
+ |---|---|---|
101
+ | RAT-SQL (PAUQ, mono) | 51 % | 49 % |
102
+ | BRIDGE (PAUQ, mono) | 52 % | 48 % |
103
+ | **Qwen2.5-Coder-3B + QLoRA (эта Ρ€Π°Π±ΠΎΡ‚Π°)** | **40 %** | **71.9 %** |
104
+
105
+ Высокий Ρ€Π°Π·Ρ€Ρ‹Π² ΠΌΠ΅ΠΆΠ΄Ρƒ EM ΠΈ EX ΠΎΡ‚Ρ€Π°ΠΆΠ°Π΅Ρ‚ спСцифику Π·Π°Π΄Π°Ρ‡ΠΈ: модСль
106
+ Π³Π΅Π½Π΅Ρ€ΠΈΡ€ΡƒΠ΅Ρ‚ сСмантичСски ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹Π΅ запросы, Π½ΠΎ синтаксичСски
107
+ ΠΎΡ‚Π»ΠΈΡ‡Π°ΡŽΡ‰ΠΈΠ΅ΡΡ ΠΎΡ‚ эталонных (Π΄Ρ€ΡƒΠ³ΠΎΠΉ порядок условий, Π°Π»ΡŒΡ‚Π΅Ρ€Π½Π°Ρ‚ΠΈΠ²Π½Ρ‹Π΅
108
+ JOIN-стратСгии, ΠΈΠ½Ρ‹Π΅ псСвдонимы).
109
+
110
+ ## ΠžΠ³Ρ€Π°Π½ΠΈΡ‡Π΅Π½ΠΈΡ
111
+
112
+ - МодСль ΠΎΠ±ΡƒΡ‡Π΅Π½Π° Π½Π° схСмах PAUQ/Spider. На сущСствСнно ΠΎΡ‚Π»ΠΈΡ‡Π°ΡŽΡ‰ΠΈΡ…ΡΡ
113
+ ΠΏΡ€Π΅Π΄ΠΌΠ΅Ρ‚Π½Ρ‹Ρ… областях качСство ΠΌΠΎΠΆΠ΅Ρ‚ ΠΏΠ°Π΄Π°Ρ‚ΡŒ; для Π°Π΄Π°ΠΏΡ‚Π°Ρ†ΠΈΠΈ
114
+ прСдусмотрСн ΠΌΠ΅Ρ…Π°Π½ΠΈΠ·ΠΌ бизнСс-словаря Π² основном ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π΅.
115
+ - НаиболСС слоТныС классы запросов β€” Π²Π»ΠΎΠΆΠ΅Π½Π½Ρ‹Π΅ SELECT, конструкции
116
+ INTERSECT/EXCEPT/UNION, HAVING с нСсколькими условиями β€” ΠΎΡΡ‚Π°ΡŽΡ‚ΡΡ
117
+ слабым мСстом.
118
+ - ΠŸΡ€ΠΈ Ρ€Π°Π±ΠΎΡ‚Π΅ Π½Π° CPU Π±Π΅Π· GPU инфСрСнс Π·Π°Π½ΠΈΠΌΠ°Π΅Ρ‚ 15–30 сСкунд Π½Π° запрос.
119
+ - АдаптСр ΠΎΡ€ΠΈΠ΅Π½Ρ‚ΠΈΡ€ΠΎΠ²Π°Π½ Π½Π° SQLite. Для PostgreSQL/MySQL рСкомСндуСтся
120
+ явно ΡƒΠΊΠ°Π·Π°Ρ‚ΡŒ Π΄ΠΈΠ°Π»Π΅ΠΊΡ‚ Π² ΠΏΡ€ΠΎΠΌΠΏΡ‚Π΅.
121
+
122
+ ## ЛицСнзия
123
+
124
+ Apache 2.0. Базовая модСль Qwen2.5-Coder Ρ‚Π°ΠΊΠΆΠ΅ распространяСтся ΠΏΠΎΠ΄
125
+ Π»ΠΈΡ†Π΅Π½Π·ΠΈΠ΅ΠΉ Apache 2.0; датасСт PAUQ β€” Apache 2.0.
126
+
127
+ ## Π¦ΠΈΡ‚ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅
128
+
129
+ ```bibtex
130
+ @misc{ru2sql-qlora-2026,
131
+ title = {Ru2SQL: Russian Text-to-SQL via QLoRA on Qwen2.5-Coder},
132
+ author = {Siryazeev, Danis},
133
+ year = {2026},
134
+ note = {Bachelor's thesis, Kazan National Research Technical University},
135
+ }
136
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/demo/sales.sqlite-journal DELETED
Binary file (512 Bytes)
 
evaluate_pauq.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Π‘ΠΊΡ€ΠΈΠΏΡ‚ ΠΏΡ€ΠΎΠ³ΠΎΠ½Π° ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊ Ru2SQL Π½Π° Π²Π°Π»ΠΈΠ΄Π°Ρ†ΠΈΠΎΠ½Π½ΠΎΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠ΅ PAUQ.
3
+
4
+ Π‘Ρ‡ΠΈΡ‚Π°Π΅Ρ‚ Exact Match (EM) ΠΈ Execution Accuracy (EX) для Π΄ΠΎΠΎΠ±ΡƒΡ‡Π΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
5
+ Qwen2.5-Coder-3B + QLoRA. Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ Π½Π° Kaggle (T4 GPU), Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚
6
+ сохраняСтся Π² eval_results.csv ΠΈ eval_summary.json.
7
+
8
+ Π˜ΡΡ‚ΠΎΡ‡Π½ΠΈΠΊ Ρ†ΠΈΡ„Ρ€ для Ρ€Π°Π·Π΄Π΅Π»Π° 4.3 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки.
9
+
10
+ ИспользованиС на Kaggle:
11
+ 1. Π—Π°Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ ΠΏΡ€ΠΎΠ΅ΠΊΡ‚ Ρ†Π΅Π»ΠΈΠΊΠΎΠΌ (Ρ‡Π΅Ρ€Π΅Π· Kaggle Dataset ΠΈΠ»ΠΈ git clone).
12
+ 2. Π£ΡΡ‚Π°Π½ΠΎΠ²ΠΈΡ‚ΡŒ ADAPTER_ID Π½Π° свой HF-Ρ€Π΅ΠΏΠΎ.
13
+ 3. python evaluate_pauq.py
14
+
15
+ ΠŸΠΎΡΡ‚ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° ΠΈ нормализация SQL ΠΈΠΌΠΏΠΎΡ€Ρ‚ΠΈΡ€ΡƒΡŽΡ‚ΡΡ ΠΈΠ· src/models/postprocess.py,
16
+ Ρ‡Ρ‚ΠΎΠ±Ρ‹ ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ локально ΠΈ Π½Π° Kaggle ΠΎΡΡ‚Π°Π²Π°Π»ΠΈΡΡŒ сопоставимыми.
17
+ """
18
+
19
+ import sys
20
+ from pathlib import Path
21
+
22
+ # Π”Π΅Π»Π°Π΅ΠΌ ΠΏΠ°ΠΊΠ΅Ρ‚ src/ ΠΈΠΌΠΏΠΎΡ€Ρ‚ΠΈΡ€ΡƒΠ΅ΠΌΡ‹ΠΌ, ΠΊΠΎΠ³Π΄Π° скрипт запускаСтся ΠΈΠ· корня.
23
+ _PROJECT_ROOT = Path(__file__).resolve().parent
24
+ if str(_PROJECT_ROOT) not in sys.path:
25
+ sys.path.insert(0, str(_PROJECT_ROOT))
26
+
27
+ import csv
28
+ import json
29
+ import sqlite3
30
+
31
+ from tqdm import tqdm
32
+
33
+ import torch
34
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
35
+ from peft import PeftModel
36
+ from datasets import load_dataset
37
+
38
+ from src.models.postprocess import (
39
+ normalize_sql,
40
+ strip_model_artifacts as strip_artifacts,
41
+ )
42
+
43
+ # ─── CONFIG ───────────────────────────────────────────────────────────────────
44
+
45
+ BASE_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
46
+ ADAPTER_ID = "Tyycha/qwen-coder-pauq-lora"
47
+ PAUQ_SPLIT = "validation" # 1034 ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π°
48
+ MAX_NEW_TOKENS = 256
49
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
50
+
51
+ # ΠŸΡƒΡ‚ΡŒ ΠΊ ΠΏΠ°ΠΏΠΊΠ΅ с SQLite-Π±Π°Π·Π°ΠΌΠΈ PAUQ (Spider databases)
52
+ # На Kaggle: скачай https://drive.google.com/uc?id=1iRDVHLr4mX2wQKSgA9VxUUFpj-3-Kj5B
53
+ # Или ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠΉ датасСт: kaggle datasets download -d wikisql/spider
54
+ PAUQ_DB_DIR = Path("./pauq_databases") # ΠΏΠ°ΠΏΠΊΠ°, Π³Π΄Π΅ Π»Π΅ΠΆΠ°Ρ‚ ΠΏΠ°ΠΏΠΊΠΈ Π±Π°Π· Π΄Π°Π½Π½Ρ‹Ρ…
55
+
56
+ # ─── OUTPUT ───────────────────────────────────────────────────────────────────
57
+
58
+ RESULTS_FILE = "eval_results.csv"
59
+ SUMMARY_FILE = "eval_summary.json"
60
+
61
+ # ─── LOAD MODEL ───────────────────────────────────────────────────────────────
62
+
63
+ print("Loading model...")
64
+ bnb_config = BitsAndBytesConfig(
65
+ load_in_4bit=True,
66
+ bnb_4bit_quant_type="nf4",
67
+ bnb_4bit_compute_dtype=torch.bfloat16,
68
+ bnb_4bit_use_double_quant=True,
69
+ )
70
+
71
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
72
+ base_model = AutoModelForCausalLM.from_pretrained(
73
+ BASE_MODEL_ID,
74
+ quantization_config=bnb_config,
75
+ device_map="auto",
76
+ trust_remote_code=True
77
+ )
78
+ model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
79
+ model.eval()
80
+ print(f"Model loaded on {DEVICE}")
81
+
82
+ # ─── SCHEMA RETRIEVER ─────────────────────────────────────────────────────────
83
+
84
+ def get_schema(db_id: str) -> str:
85
+ """Extract CREATE TABLE statements from SQLite database."""
86
+ db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
87
+ if not db_path.exists():
88
+ return f"-- Database {db_id} not found"
89
+
90
+ conn = sqlite3.connect(str(db_path))
91
+ cursor = conn.cursor()
92
+
93
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
94
+ tables = [row[0] for row in cursor.fetchall()]
95
+
96
+ schema_parts = []
97
+ for table in tables:
98
+ cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table,))
99
+ create_sql = cursor.fetchone()
100
+ if create_sql and create_sql[0]:
101
+ schema_parts.append(create_sql[0])
102
+ # Add sample rows
103
+ try:
104
+ cursor.execute(f"SELECT * FROM \"{table}\" LIMIT 3")
105
+ rows = cursor.fetchall()
106
+ cursor.execute(f"PRAGMA table_info(\"{table}\")")
107
+ cols = [col[1] for col in cursor.fetchall()]
108
+ if rows:
109
+ schema_parts.append(f"-- Sample data for {table}:")
110
+ schema_parts.append("-- " + " | ".join(cols))
111
+ for row in rows[:3]:
112
+ schema_parts.append("-- " + " | ".join(str(v) for v in row))
113
+ except Exception:
114
+ pass
115
+
116
+ conn.close()
117
+ return "\n\n".join(schema_parts)
118
+
119
+ # ─── PROMPT BUILDER ───────────────────────────────────────────────────────────
120
+
121
+ SYSTEM_PROMPT = (
122
+ "Π’Ρ‹ β€” ассистСнт, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΏΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΡƒΠ΅Ρ‚ вопросы Π½Π° русском языкС Π² SQL-запросы. "
123
+ "ΠžΡ‚Π²Π΅Ρ‡Π°ΠΉ Π’ΠžΠ›Π¬ΠšΠž SQL-запросом Π±Π΅Π· объяснСний, ΠΊΠΎΠΌΠΌΠ΅Π½Ρ‚Π°Ρ€ΠΈΠ΅Π² ΠΈ markdown-Ρ€Π°Π·ΠΌΠ΅Ρ‚ΠΊΠΈ."
124
+ )
125
+
126
+ def build_prompt(question: str, schema: str) -> str:
127
+ user_msg = f"Schema:\n{schema}\n\nQuestion: {question}\n\nSQL:"
128
+ messages = [
129
+ {"role": "system", "content": SYSTEM_PROMPT},
130
+ {"role": "user", "content": user_msg},
131
+ ]
132
+ return tokenizer.apply_chat_template(
133
+ messages, tokenize=False, add_generation_prompt=True
134
+ )
135
+
136
+ # ─── SQL POSTPROCESSING ───────────────────────────────────────────────────────
137
+ # strip_artifacts ΠΈ normalize_sql ΠΈΠΌΠΏΠΎΡ€Ρ‚ΠΈΡ€ΡƒΡŽΡ‚ΡΡ ΠΈΠ· src/models/postprocess.py
138
+ # для Π³Π°Ρ€Π°Π½Ρ‚ΠΈΠΈ Ρ‚ΠΎΠ³ΠΎ, Ρ‡Ρ‚ΠΎ ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ Π½Π° Kaggle ΠΈ Π² Π»ΠΎΠΊΠ°Π»ΡŒΠ½Ρ‹Ρ… тСстах ΡΡ‡ΠΈΡ‚Π°ΡŽΡ‚ΡΡ
139
+ # ΠΏΠΎ ΠΎΠ΄Π½ΠΎΠΉ ΠΈ Ρ‚ΠΎΠΉ ΠΆΠ΅ Π»ΠΎΠ³ΠΈΠΊΠ΅.
140
+
141
+ # ─── EXECUTION ────────────────────────────────────────────────────────────────
142
+
143
+ def execute_sql(sql: str, db_id: str):
144
+ """Execute SQL and return result set as frozenset of tuples."""
145
+ db_path = PAUQ_DB_DIR / db_id / f"{db_id}.sqlite"
146
+ if not db_path.exists():
147
+ return None
148
+ try:
149
+ uri = f"file:{db_path}?mode=ro"
150
+ conn = sqlite3.connect(uri, uri=True)
151
+ cursor = conn.cursor()
152
+ cursor.execute(sql)
153
+ rows = cursor.fetchall()
154
+ conn.close()
155
+ return frozenset(tuple(str(v) for v in row) for row in rows)
156
+ except Exception:
157
+ return None
158
+
159
+ # ─── INFERENCE ────────────────────────────────────────────────────────────────
160
+
161
+ def generate_sql(question: str, schema: str) -> str:
162
+ prompt = build_prompt(question, schema)
163
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
164
+
165
+ with torch.no_grad():
166
+ outputs = model.generate(
167
+ **inputs,
168
+ max_new_tokens=MAX_NEW_TOKENS,
169
+ do_sample=False,
170
+ temperature=1.0,
171
+ pad_token_id=tokenizer.eos_token_id,
172
+ )
173
+
174
+ generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
175
+ return strip_artifacts(generated)
176
+
177
+ # ─── MAIN EVALUATION LOOP ─────────────────────────────────────────────────────
178
+
179
+ def main():
180
+ print(f"Loading PAUQ {PAUQ_SPLIT} split...")
181
+ dataset = load_dataset("ai-forever/PAUQ", split=PAUQ_SPLIT)
182
+ print(f"Loaded {len(dataset)} examples")
183
+
184
+ results = []
185
+ em_correct = 0
186
+ ex_correct = 0
187
+ ex_total = 0 # only count where execution was possible
188
+
189
+ with open(RESULTS_FILE, "w", newline="", encoding="utf-8") as f:
190
+ writer = csv.DictWriter(f, fieldnames=[
191
+ "idx", "db_id", "question", "gold_sql", "pred_sql",
192
+ "em", "ex", "error"
193
+ ])
194
+ writer.writeheader()
195
+
196
+ for idx, example in enumerate(tqdm(dataset, desc="Evaluating")):
197
+ question = example.get("question_ru") or example.get("question", "")
198
+ gold_sql = example.get("query", "")
199
+ db_id = example.get("db_id", "")
200
+
201
+ try:
202
+ schema = get_schema(db_id)
203
+ pred_sql = generate_sql(question, schema)
204
+
205
+ # Exact Match
206
+ norm_pred = normalize_sql(pred_sql)
207
+ norm_gold = normalize_sql(gold_sql)
208
+ em = int(norm_pred == norm_gold)
209
+
210
+ # Execution Accuracy
211
+ pred_result = execute_sql(pred_sql, db_id)
212
+ gold_result = execute_sql(gold_sql, db_id)
213
+
214
+ if gold_result is not None:
215
+ ex = int(pred_result == gold_result)
216
+ ex_correct += ex
217
+ ex_total += 1
218
+ else:
219
+ ex = None
220
+
221
+ em_correct += em
222
+ error = ""
223
+
224
+ except Exception as e:
225
+ pred_sql = ""
226
+ em = 0
227
+ ex = None
228
+ error = str(e)[:200]
229
+
230
+ row = {
231
+ "idx": idx, "db_id": db_id, "question": question,
232
+ "gold_sql": gold_sql, "pred_sql": pred_sql,
233
+ "em": em, "ex": ex, "error": error
234
+ }
235
+ writer.writerow(row)
236
+ results.append(row)
237
+
238
+ # Progress every 100 examples
239
+ if (idx + 1) % 100 == 0:
240
+ cur_em = em_correct / (idx + 1)
241
+ cur_ex = ex_correct / max(ex_total, 1)
242
+ print(f"[{idx+1}/{len(dataset)}] EM={cur_em:.3f} EX={cur_ex:.3f}")
243
+
244
+ # Final summary
245
+ n = len(dataset)
246
+ final_em = em_correct / n
247
+ final_ex = ex_correct / max(ex_total, 1)
248
+
249
+ summary = {
250
+ "model": ADAPTER_ID,
251
+ "split": PAUQ_SPLIT,
252
+ "n_examples": n,
253
+ "exact_match": round(final_em, 4),
254
+ "execution_accuracy": round(final_ex, 4),
255
+ "em_correct": em_correct,
256
+ "ex_correct": ex_correct,
257
+ "ex_total": ex_total
258
+ }
259
+
260
+ with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
261
+ json.dump(summary, f, ensure_ascii=False, indent=2)
262
+
263
+ print("\n" + "="*50)
264
+ print(f"RESULTS on PAUQ {PAUQ_SPLIT} ({n} examples)")
265
+ print(f" Exact Match (EM): {final_em:.1%} ({em_correct}/{n})")
266
+ print(f" Execution Accuracy (EX): {final_ex:.1%} ({ex_correct}/{ex_total})")
267
+ print(f"\nDetailed results saved to: {RESULTS_FILE}")
268
+ print(f"Summary saved to: {SUMMARY_FILE}")
269
+
270
+ if __name__ == "__main__":
271
+ main()
scripts/run_app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Π›Π°ΡƒΠ½Ρ‡Π΅Ρ€ Π΄Π²ΡƒΡ… процСссов Ru2SQL: FastAPI + Streamlit.
2
+
3
+ Запуск:
4
+ python scripts/run_app.py
5
+
6
+ Π‘ΠΊΡ€ΠΈΠΏΡ‚ стартуСт uvicorn Π² Ρ„ΠΎΠ½Π΅, ΠΆΠ΄Ρ‘Ρ‚ ΠΏΠΎΠΊΠ° /health Π½Π°Ρ‡Π½Ρ‘Ρ‚ ΠΎΡ‚Π²Π΅Ρ‡Π°Ρ‚ΡŒ
7
+ (модСль Qwen ΠΌΠΎΠΆΠ΅Ρ‚ ΠΊΠ°Ρ‡Π°Ρ‚ΡŒΡΡ 5–10 ΠΌΠΈΠ½ΡƒΡ‚ ΠΏΡ€ΠΈ ΠΏΠ΅Ρ€Π²ΠΎΠΌ запускС), ΠΈ
8
+ ΠΏΠΎΠ΄Π½ΠΈΠΌΠ°Π΅Ρ‚ Streamlit ΠΊΠ°ΠΊ основной процСсс. По Ctrl+C ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½ΠΎ
9
+ останавливаСт ΠΎΠ±Π°.
10
+
11
+ ПолСзно ΠΏΡ€ΠΈ Ρ€Π°Π·Ρ€Π°Π±ΠΎΡ‚ΠΊΠ΅ ΠΈ для Π΄Π΅ΠΌΠΎ Π½Π° Π·Π°Ρ‰ΠΈΡ‚Π΅: ΠΎΠ΄ΠΈΠ½ Ctrl+C β€” ΠΈ ΠΎΠ±Π°
12
+ процСсса Π·Π°Π²Π΅Ρ€ΡˆΠ΅Π½Ρ‹, Π½Π΅ Π½ΡƒΠΆΠ½ΠΎ ΠΈΡΠΊΠ°Ρ‚ΡŒ висящий uvicorn Π² taskmgr.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import signal
19
+ import subprocess
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import httpx
25
+
26
+ ROOT = Path(__file__).resolve().parent.parent
27
+ API_HOST = os.environ.get("API_HOST", "127.0.0.1")
28
+ API_PORT = int(os.environ.get("API_PORT", "8000"))
29
+ STREAMLIT_PORT = int(os.environ.get("STREAMLIT_PORT", "8501"))
30
+ # Streamlit ΡΠ»ΡƒΡˆΠ°Π΅Ρ‚ 0.0.0.0 Π² ΠΊΠΎΠ½Ρ‚Π΅ΠΉΠ½Π΅Ρ€Π΅ (Π½ΡƒΠΆΠ΅Π½ внСшний доступ),
31
+ # Π° локально ΠΏΠΎ ΡƒΠΌΠΎΠ»Ρ‡Π°Π½ΠΈΡŽ 127.0.0.1 β€” это управляСтся ΠΎΡ‚Π΄Π΅Π»ΡŒΠ½ΠΎΠΉ ΠΏΠ΅Ρ€Π΅ΠΌΠ΅Π½Π½ΠΎΠΉ.
32
+ STREAMLIT_HOST = os.environ.get("STREAMLIT_HOST", "127.0.0.1")
33
+ API_URL = f"http://{API_HOST}:{API_PORT}"
34
+ WAIT_API_SECONDS = 600 # модСль скачиваСтся Π΄ΠΎΠ»Π³ΠΎ
35
+
36
+
37
+ def wait_for_api(timeout: int) -> bool:
38
+ """ΠžΠΏΡ€Π°ΡˆΠΈΠ²Π°Π΅ΠΌ /health, ΠΏΠΎΠΊΠ° Π½Π΅ ΠΏΠΎΠ»ΡƒΡ‡ΠΈΠΌ 200 ΠΈΠ»ΠΈ Π½Π΅ Π²Ρ‹ΠΉΠ΄Π΅Ρ‚ timeout."""
39
+ deadline = time.time() + timeout
40
+ while time.time() < deadline:
41
+ try:
42
+ r = httpx.get(f"{API_URL}/health", timeout=2.0)
43
+ if r.status_code == 200:
44
+ return True
45
+ except Exception:
46
+ pass
47
+ time.sleep(1.5)
48
+ return False
49
+
50
+
51
+ def main():
52
+ print(f"[run_app] ΠΊΠΎΡ€Π΅Π½ΡŒ ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π°: {ROOT}")
53
+ print(f"[run_app] ΡΡ‚Π°Ρ€Ρ‚ΡƒΡŽ uvicorn Π½Π° {API_URL}")
54
+
55
+ creation_flags = 0
56
+ if sys.platform == "win32":
57
+ creation_flags = subprocess.CREATE_NEW_PROCESS_GROUP
58
+
59
+ uvicorn_proc = subprocess.Popen(
60
+ [
61
+ sys.executable, "-m", "uvicorn", "src.api.main:app",
62
+ "--host", API_HOST,
63
+ "--port", str(API_PORT),
64
+ ],
65
+ cwd=str(ROOT),
66
+ creationflags=creation_flags,
67
+ )
68
+
69
+ try:
70
+ print(f"[run_app] ΠΆΠ΄Ρƒ готовности API ({WAIT_API_SECONDS} сСк макс) β€” "
71
+ "ΠΏΡ€ΠΈ ΠΏΠ΅Ρ€Π²ΠΎΠΌ запускС Qwen качаСтся с HuggingFace")
72
+ if not wait_for_api(WAIT_API_SECONDS):
73
+ print("[run_app] API Π½Π΅ поднялся Π·Π° ΠΎΡ‚Π²Π΅Π΄Ρ‘Π½Π½ΠΎΠ΅ врСмя. ΠŸΡ€ΠΎΠ²Π΅Ρ€ΡŒ Π»ΠΎΠ³ΠΈ uvicorn.")
74
+ uvicorn_proc.terminate()
75
+ sys.exit(1)
76
+ print(f"[run_app] API Π³ΠΎΡ‚ΠΎΠ², ΡΡ‚Π°Ρ€Ρ‚ΡƒΡŽ Streamlit Π½Π° http://{STREAMLIT_HOST}:{STREAMLIT_PORT}")
77
+
78
+ env = os.environ.copy()
79
+ env["RU2SQL_API_URL"] = API_URL
80
+ streamlit_cmd = [
81
+ sys.executable, "-m", "streamlit", "run", "streamlit_app.py",
82
+ "--server.port", str(STREAMLIT_PORT),
83
+ "--server.address", STREAMLIT_HOST,
84
+ ]
85
+ streamlit_proc = subprocess.Popen(streamlit_cmd, cwd=str(ROOT), env=env)
86
+
87
+ # Π“Π»Π°Π²Π½Ρ‹ΠΉ Ρ†ΠΈΠΊΠ» β€” ΠΆΠ΄Ρ‘ΠΌ, ΠΏΠΎΠΊΠ° streamlit ΠΈΠ»ΠΈ uvicorn Π½Π΅ ΡƒΠΏΠ°Π΄Ρ‘Ρ‚
88
+ try:
89
+ while True:
90
+ if uvicorn_proc.poll() is not None:
91
+ print("[run_app] uvicorn Π·Π°Π²Π΅Ρ€ΡˆΠΈΠ»ΡΡ, ΠΎΡΡ‚Π°Π½Π°Π²Π»ΠΈΠ²Π°ΡŽ streamlit")
92
+ streamlit_proc.terminate()
93
+ break
94
+ if streamlit_proc.poll() is not None:
95
+ print("[run_app] streamlit Π·Π°Π²Π΅Ρ€ΡˆΠΈΠ»ΡΡ, ΠΎΡΡ‚Π°Π½Π°Π²Π»ΠΈΠ²Π°ΡŽ uvicorn")
96
+ break
97
+ time.sleep(1.0)
98
+ except KeyboardInterrupt:
99
+ print("\n[run_app] Ctrl+C ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½, Π·Π°Π²Π΅Ρ€ΡˆΠ°ΡŽ процСссы…")
100
+ streamlit_proc.terminate()
101
+ finally:
102
+ try:
103
+ uvicorn_proc.terminate()
104
+ uvicorn_proc.wait(timeout=5)
105
+ except subprocess.TimeoutExpired:
106
+ uvicorn_proc.kill()
107
+ except Exception:
108
+ pass
109
+ print("[run_app] ΠΎΠ±Π° процСсса остановлСны")
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()
scripts/smoke_local.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Π›ΠΎΠΊΠ°Π»ΡŒΠ½Ρ‹ΠΉ smoke-тСст работоспособности Ru2SQL.
2
+
3
+ ΠŸΡ€ΠΎΠ³ΠΎΠ½ΡΠ΅Ρ‚ ΠΊΠ»ΡŽΡ‡Π΅Π²Ρ‹Π΅ слои Π±Π΅Π· поднятия Streamlit/FastAPI ΠΊΠ°ΠΊ ΠΎΡ‚Π΄Π΅Π»ΡŒΠ½Ρ‹Ρ…
4
+ процСссов. ΠŸΠΎΠΊΡ€Ρ‹Π²Π°Π΅Ρ‚: ΠΈΠΌΠΏΠΎΡ€Ρ‚Ρ‹, demo-Π±Π°Π·Ρƒ, vocabulary, prompt,
5
+ постобработку, guardrail, Π½ΠΎΠ²Ρ‹ΠΉ API Ρ‡Π΅Ρ€Π΅Π· FastAPI TestClient ΠΈ
6
+ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ β€” ΠΊΠΎΡ€ΠΎΡ‚ΠΊΠΈΠΉ инфСрСнс Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
7
+
8
+ Запуск:
9
+ python scripts/smoke_local.py
10
+ python scripts/smoke_local.py --with-model # ΠΌΠ΅Π΄Π»Π΅Π½Π½ΠΎ, Π³Ρ€ΡƒΠ·ΠΈΡ‚ Qwen
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import sys
17
+ import time
18
+ from pathlib import Path
19
+
20
+ ROOT = Path(__file__).resolve().parent.parent
21
+ sys.path.insert(0, str(ROOT))
22
+
23
+ GREEN = "\033[32m"
24
+ RED = "\033[31m"
25
+ YELLOW = "\033[33m"
26
+ RESET = "\033[0m"
27
+
28
+ results: list[tuple[str, str, str]] = []
29
+
30
+
31
+ def check(name: str):
32
+ def decorator(fn):
33
+ try:
34
+ t0 = time.time()
35
+ fn()
36
+ dt = time.time() - t0
37
+ results.append(("OK", name, f"{dt*1000:.0f} мс"))
38
+ except Exception as e:
39
+ results.append(("FAIL", name, f"{type(e).__name__}: {e}"))
40
+ return fn
41
+ return decorator
42
+
43
+
44
+ # ──────────────────────────────────────────────────────────────────────
45
+ # Π‘Π»ΠΎΠΉ 1 β€” ΠΈΠΌΠΏΠΎΡ€Ρ‚Ρ‹, ΠΊΠΎΠ½Ρ„ΠΈΠ³
46
+ # ──────────────────────────────────────────────────────────────────────
47
+ @check("Π˜ΠΌΠΏΠΎΡ€Ρ‚Ρ‹ ядра")
48
+ def _():
49
+ from src.config import settings
50
+ from src.api.schemas import GenerateRequest, QueryRequest, SchemaRequest
51
+ from src.business.vocabulary import BusinessVocabulary
52
+ from src.data.prompt import build_chat_messages, BASE_SYSTEM_PROMPT
53
+ from src.data.schema_provider import (
54
+ SpiderSchemaProvider, ConnectionSchemaProvider, TableSchema,
55
+ )
56
+ from src.db.connector import DbConnector
57
+ from src.db.executor import SqlExecutor
58
+ from src.models.postprocess import postprocess, is_select_only, is_valid_sql
59
+ assert settings.base_model_name.startswith("Qwen")
60
+
61
+
62
+ # ──────────────────────────────────────────────────────────────────────
63
+ # Π‘Π»ΠΎΠΉ 2 β€” Π΄Π΅ΠΌΠΎ-Π±Π°Π·Π° ΠΈ read-only guardrail
64
+ # ──────────────────────────────────────────────────────────────────────
65
+ @check("Π”Π΅ΠΌΠΎ-Π±Π°Π·Π° sales.sqlite: ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΈ схСма")
66
+ def _():
67
+ from src.db.connector import DbConnector
68
+ db = ROOT / "data" / "demo" / "sales.sqlite"
69
+ assert db.exists(), f"НС Π½Π°ΠΉΠ΄Π΅Π½ Ρ„Π°ΠΉΠ» {db} β€” запусти data/demo/create_demo_db.py"
70
+ conn = DbConnector(str(db))
71
+ tables = conn.list_tables()
72
+ assert set(tables) == {"customers", "managers", "products", "orders", "order_items"}
73
+
74
+
75
+ @check("Π”Π΅ΠΌΠΎ-Π±Π°Π·Π°: Ρ€Π΅Π°Π»ΡŒΠ½Ρ‹ΠΉ SELECT Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠΈ")
76
+ def _():
77
+ from src.db.executor import SqlExecutor
78
+ ex = SqlExecutor(str(ROOT / "data" / "demo" / "sales.sqlite"))
79
+ res = ex.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
80
+ assert res.success, res.error
81
+ assert res.rows[0][0] > 0
82
+
83
+
84
+ @check("Read-only guardrail: DELETE отклоняСтся Π΄Ρ€Π°ΠΉΠ²Π΅Ρ€ΠΎΠΌ")
85
+ def _():
86
+ from src.db.executor import SqlExecutor
87
+ ex = SqlExecutor(str(ROOT / "data" / "demo" / "sales.sqlite"))
88
+ res = ex.run("DELETE FROM orders")
89
+ assert not res.success
90
+ assert ex.run("SELECT COUNT(*) FROM orders").rows[0][0] == 120
91
+
92
+
93
+ # ──────────────────────────────────────────────────────────────────────
94
+ # Π‘Π»ΠΎΠΉ 3 β€” BusinessVocabulary ΠΈΠ· YAML
95
+ # ──────────────────────────────────────────────────────────────────────
96
+ @check("BusinessVocabulary ΠΈΠ· configs/example_vocabulary.yaml")
97
+ def _():
98
+ from src.business.vocabulary import BusinessVocabulary
99
+ vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
100
+ assert bool(vocab)
101
+ assert "Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°" in vocab.terms
102
+ assert "SUM(orders.amount)" in vocab.render_system_context()
103
+
104
+
105
+ # ──────────────────────────────────────────────────────────────────────
106
+ # Π‘Π»ΠΎΠΉ 4 β€” PromptBuilder с Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ схСмой ΠΈ словарём
107
+ # ──────────────────────────────────────────────────────────────────────
108
+ @check("PromptBuilder: vocabulary ΡƒΡ…ΠΎΠ΄ΠΈΡ‚ Π² system, Π½Π΅ Π² user")
109
+ def _():
110
+ from src.business.vocabulary import BusinessVocabulary
111
+ from src.data.prompt import build_chat_messages
112
+ from src.db.connector import DbConnector
113
+ vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
114
+ schema = DbConnector(str(ROOT / "data" / "demo" / "sales.sqlite")).render_schema()
115
+ msgs = build_chat_messages(schema, "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?", vocabulary=vocab)
116
+ assert "SUM(orders.amount)" in msgs[0]["content"]
117
+ assert "SUM(orders.amount)" not in msgs[1]["content"]
118
+
119
+
120
+ # ──────────────────────────────────────────────────────────────────────
121
+ # Π‘Π»ΠΎΠΉ 5 β€” постобработка ΠΈ guardrail
122
+ # ──────────────────────────────────────────────────────────────────────
123
+ @check("ΠŸΠΎΡΡ‚ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ°: markdown, прСфиксы, truncated, мусор")
124
+ def _():
125
+ from src.models.postprocess import postprocess, is_select_only
126
+ cases = [
127
+ ("```sql\nSELECT 1;\n```", lambda s: s.upper().startswith("SELECT")),
128
+ ("ΠžΡ‚Π²Π΅Ρ‚: SELECT name FROM customers;", lambda s: s.startswith("SELECT name")),
129
+ ("SELECT * FROM orders WHERE", lambda s: s == ""),
130
+ ("просто тСкст Π±Π΅Π· SQL", lambda s: s == ""),
131
+ ]
132
+ for raw, ok in cases:
133
+ assert ok(postprocess(raw)), f"Π‘Π±ΠΎΠΉ Π½Π°: {raw!r} β†’ {postprocess(raw)!r}"
134
+ assert is_select_only("SELECT 1") is True
135
+ assert is_select_only("DROP TABLE x") is False
136
+ assert is_select_only("DELETE FROM x") is False
137
+
138
+
139
+ # ──────────────────────────────────────────────────────────────────────
140
+ # Π‘Π»ΠΎΠΉ 6 β€” FastAPI Ρ‡Π΅Ρ€Π΅Π· TestClient с Π·Π°ΠΌΠΎΠΊΠ°Π½Π½Ρ‹ΠΌ InferenceEngine
141
+ # ──────────────────────────────────────────────────────────────────────
142
+ @check("FastAPI: /health, /schema, /query с ΠΏΠΎΠ΄ΠΌΠ΅Π½Ρ‘Π½Π½Ρ‹ΠΌ engine")
143
+ def _():
144
+ try:
145
+ from fastapi.testclient import TestClient
146
+ except ImportError:
147
+ raise AssertionError("ΠŸΠΎΡΡ‚Π°Π²ΡŒ fastapi[all] ΠΈΠ»ΠΈ httpx β€” TestClient нСдоступСн")
148
+
149
+ from src.api import dependencies as deps
150
+ from src.api.main import app
151
+ from src.models.inference import GenerationResult
152
+
153
+ class FakeEngine:
154
+ loaded = True
155
+ base_model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"
156
+ def generate(self, schema, question, vocabulary=None, **kw):
157
+ # Π­ΠΌΡƒΠ»ΠΈΡ€ΡƒΠ΅ΠΌ Π²Π°Π»ΠΈΠ΄Π½Ρ‹ΠΉ SQL β€” pipeline Π΄ΠΎΠ»ΠΆΠ΅Π½ ΠΏΡ€ΠΎΠΏΡƒΡΡ‚ΠΈΡ‚ΡŒ Ρ‡Π΅Ρ€Π΅Π· guardrail.
158
+ sql = "SELECT SUM(amount) FROM orders WHERE status = 'paid'"
159
+ return GenerationResult(sql=sql, raw_output=sql)
160
+
161
+ app.dependency_overrides[deps.get_engine] = lambda: FakeEngine()
162
+ try:
163
+ with TestClient(app) as client:
164
+ # /health
165
+ r = client.get("/health")
166
+ assert r.status_code == 200
167
+ assert r.json()["model_loaded"] is True
168
+
169
+ # /schema Π½Π° Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ demo-Π±Π°Π·Π΅
170
+ db_path = ROOT / "data" / "demo" / "sales.sqlite"
171
+ r = client.post("/schema", json={
172
+ "connection_string": str(db_path),
173
+ "include_samples": True,
174
+ })
175
+ assert r.status_code == 200, r.text
176
+ tables = r.json()["tables"]
177
+ assert {t["name"] for t in tables} == {
178
+ "customers", "managers", "products", "orders", "order_items"
179
+ }
180
+
181
+ # /query Π½Π° Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ demo-Π±Π°Π·Π΅, FakeEngine отдаст Π²Π°Π»ΠΈΠ΄Π½Ρ‹ΠΉ SELECT
182
+ r = client.post("/query", json={
183
+ "question": "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° ΠΎΠΏΠ»Π°Ρ‡Π΅Π½Π½Ρ‹Π΅ Π·Π°ΠΊΠ°Π·Ρ‹?",
184
+ "connection_string": str(db_path),
185
+ "execute": True,
186
+ "vocabulary": {
187
+ "company": "Π”Π΅ΠΌΠΎ",
188
+ "terms": {"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount) WHERE status='paid'"},
189
+ },
190
+ })
191
+ assert r.status_code == 200, r.text
192
+ body = r.json()
193
+ assert body["is_valid_sql"] is True
194
+ assert body["execution"] is not None
195
+ assert body["execution"]["rows"][0][0] > 0
196
+
197
+ # /query с DELETE β€” Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±Ρ‹Ρ‚ΡŒ ΠΎΡ‚Π±ΠΈΡ‚ guardrail'ΠΎΠΌ
198
+ class DropEngine(FakeEngine):
199
+ def generate(self, *a, **kw):
200
+ return GenerationResult(
201
+ sql="DELETE FROM orders WHERE id=1",
202
+ raw_output="DELETE FROM orders WHERE id=1",
203
+ )
204
+ app.dependency_overrides[deps.get_engine] = lambda: DropEngine()
205
+ r = client.post("/query", json={
206
+ "question": "Π£Π΄Π°Π»ΠΈ Π·Π°ΠΊΠ°Π· 1",
207
+ "connection_string": str(db_path),
208
+ "execute": True,
209
+ })
210
+ assert r.status_code == 200, r.text
211
+ body = r.json()
212
+ assert body["execution"] is None
213
+ assert body["error"] and "Π³Π²Π°Ρ€Π΄Π΅ΠΉΠ»" in body["error"].lower()
214
+ finally:
215
+ app.dependency_overrides.clear()
216
+
217
+
218
+ # ──────────────────────────────────────────────────────────────────────
219
+ # Π‘Π»ΠΎΠΉ 7 β€” ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½Ρ‹ΠΉ инфСрСнс Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
220
+ # ──────────────────────────────────────────────────────────────────────
221
+ def run_model_smoke():
222
+ @check("InferenceEngine: Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ ΠΎΠ΄Π½Π° гСнСрация")
223
+ def _():
224
+ from src.business.vocabulary import BusinessVocabulary
225
+ from src.db.connector import DbConnector
226
+ from src.models.inference import InferenceEngine
227
+
228
+ engine = InferenceEngine()
229
+ engine.load()
230
+ assert engine.loaded
231
+ schema = DbConnector(str(ROOT / "data" / "demo" / "sales.sqlite")).render_schema()
232
+ vocab = BusinessVocabulary.from_yaml(ROOT / "configs" / "example_vocabulary.yaml")
233
+ res = engine.generate(schema, "Какая суммарная Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° 2026 Π³ΠΎΠ΄?", vocab)
234
+ assert res.sql and res.sql.upper().startswith("SELECT"), \
235
+ f"МодСль Π²Π΅Ρ€Π½ΡƒΠ»Π°: {res.raw_output!r}"
236
+
237
+
238
+ # ──────────────────────────────────────────────────────────────────────
239
+ # main
240
+ # ──────────────────────────────────────────────────────────────────────
241
+ def main():
242
+ parser = argparse.ArgumentParser()
243
+ parser.add_argument(
244
+ "--with-model", action="store_true",
245
+ help="Π”ΠΎΠΏΠΎΠ»Π½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ ΠΏΡ€ΠΎΠ³Π½Π°Ρ‚ΡŒ Ρ€Π΅Π°Π»ΡŒΠ½Ρ‹ΠΉ инфСрСнс Qwen (ΠΌΠ΅Π΄Π»Π΅Π½Π½ΠΎ, ~30 сСк Π½Π° CPU)",
246
+ )
247
+ args = parser.parse_args()
248
+
249
+ if args.with_model:
250
+ run_model_smoke()
251
+
252
+ print()
253
+ print("=" * 64)
254
+ print("Smoke-ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Ru2SQL")
255
+ print("=" * 64)
256
+ for status, name, info in results:
257
+ color = GREEN if status == "OK" else RED
258
+ mark = "βœ“" if status == "OK" else "βœ—"
259
+ print(f" {color}{mark}{RESET} {name} {YELLOW}[{info}]{RESET}")
260
+ ok = sum(1 for s, _, _ in results if s == "OK")
261
+ print("=" * 64)
262
+ summary_color = GREEN if ok == len(results) else RED
263
+ print(f" {summary_color}{ok} / {len(results)} ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΎΠΊ ΠΏΡ€ΠΎΠΉΠ΄Π΅Π½ΠΎ{RESET}")
264
+ print("=" * 64)
265
+ if ok < len(results):
266
+ print()
267
+ print("Подсказка: запусти 'pytest -v' для ΠΏΠΎΠ΄Ρ€ΠΎΠ±Π½Ρ‹Ρ… диагностик.")
268
+ sys.exit(0 if ok == len(results) else 1)
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
src/api/main.py CHANGED
@@ -3,11 +3,21 @@
3
  Запуск:
4
  uvicorn src.api.main:app --reload
5
  # Swagger UI: http://127.0.0.1:8000/docs
 
 
 
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
 
10
  import sqlite3
 
11
 
12
  from fastapi import Depends, FastAPI, HTTPException
13
  from fastapi.concurrency import run_in_threadpool
@@ -19,29 +29,59 @@ from src.api.schemas import (
19
  GenerateRequest,
20
  GenerateResponse,
21
  HealthResponse,
 
 
 
 
 
 
22
  )
 
23
  from src.config import settings
24
  from src.data.schema import SchemaRetriever
 
 
25
  from src.models.inference import InferenceEngine
26
- from src.models.postprocess import is_valid_sql
 
 
27
 
28
  app = FastAPI(
29
  title="ru2sql",
30
  description="ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ вопросов Π½Π° русском Π² SQL-запросы",
31
- version="0.1.0",
32
  lifespan=lifespan,
33
  )
34
 
35
 
 
 
 
 
36
  @app.get("/health", response_model=HealthResponse)
37
  def health(engine: InferenceEngine = Depends(get_engine)):
38
  return HealthResponse(
39
  status="ok",
40
- model_loaded=engine._loaded,
41
  base_model=engine.base_model_name,
42
  )
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @app.get("/databases", response_model=list[DatabaseInfo])
46
  def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
47
  out: list[DatabaseInfo] = []
@@ -54,21 +94,33 @@ def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
54
  return out
55
 
56
 
 
 
 
 
57
  @app.post("/generate-sql", response_model=GenerateResponse)
58
  async def generate_sql(
59
  req: GenerateRequest,
60
  engine: InferenceEngine = Depends(get_engine),
61
  retriever: SchemaRetriever = Depends(get_schema_retriever),
62
  ):
 
63
  try:
64
  schema_text = retriever.render_schema(req.db_id)
65
  except FileNotFoundError as e:
66
  raise HTTPException(status_code=404, detail=str(e)) from e
67
 
68
- # Inference синхронный ΠΈ тяТёлый β€” выносим Π² threadpool
69
- result = await run_in_threadpool(engine.generate, schema_text, req.question)
 
 
 
70
 
 
 
 
71
  valid = is_valid_sql(result.sql)
 
72
  response = GenerateResponse(
73
  sql=result.sql,
74
  raw_output=result.raw_output,
@@ -76,9 +128,15 @@ async def generate_sql(
76
  )
77
 
78
  if req.execute and valid:
 
 
 
 
 
 
79
  try:
80
  response.execution = await run_in_threadpool(
81
- _execute_sql, req.db_id, result.sql, retriever
82
  )
83
  except sqlite3.Error as e:
84
  response.error = f"SQL execution error: {e}"
@@ -86,7 +144,7 @@ async def generate_sql(
86
  return response
87
 
88
 
89
- def _execute_sql(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionResult:
90
  db_path = retriever.db_path(db_id)
91
  conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
92
  try:
@@ -95,16 +153,102 @@ def _execute_sql(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionR
95
  cur.execute(sql)
96
  rows = cur.fetchmany(100)
97
  cols = [d[0] for d in cur.description] if cur.description else []
98
- return ExecutionResult(
99
- columns=cols,
100
- rows=[list(r) for r in rows],
101
- row_count=len(rows),
102
- )
103
  finally:
104
  conn.close()
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if __name__ == "__main__":
108
  import uvicorn
109
-
110
  uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
 
3
  Запуск:
4
  uvicorn src.api.main:app --reload
5
  # Swagger UI: http://127.0.0.1:8000/docs
6
+
7
+ Π­Π½Π΄ΠΏΠΎΠΈΠ½Ρ‚Ρ‹:
8
+ GET /health β€” статус сСрвиса ΠΈ Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
9
+ GET /databases β€” список Π‘Π” ΠΈΠ· data/databases (PAUQ-структура)
10
+ POST /generate-sql β€” гСнСрация SQL ΠΏΠΎ db_id ΠΈΠ· PAUQ
11
+ POST /schema β€” схСма ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” ΠΏΠΎ connection string
12
+ POST /query β€” ΠΏΠΎΠ»Π½Ρ‹ΠΉ pipeline для ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π”
13
+ (гСнСрация + ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎΠ΅ исполнСниС + guardrail)
14
  """
15
 
16
  from __future__ import annotations
17
 
18
+ import logging
19
  import sqlite3
20
+ import time
21
 
22
  from fastapi import Depends, FastAPI, HTTPException
23
  from fastapi.concurrency import run_in_threadpool
 
29
  GenerateRequest,
30
  GenerateResponse,
31
  HealthResponse,
32
+ QueryRequest,
33
+ QueryResponse,
34
+ SchemaRequest,
35
+ SchemaResponse,
36
+ TablePayload,
37
+ ColumnPayload,
38
  )
39
+ from src.business.vocabulary import BusinessVocabulary
40
  from src.config import settings
41
  from src.data.schema import SchemaRetriever
42
+ from src.data.schema_provider import ConnectionSchemaProvider
43
+ from src.db.executor import SqlExecutor
44
  from src.models.inference import InferenceEngine
45
+ from src.models.postprocess import is_select_only, is_valid_sql
46
+
47
+ logger = logging.getLogger(__name__)
48
 
49
  app = FastAPI(
50
  title="ru2sql",
51
  description="ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ вопросов Π½Π° русском Π² SQL-запросы",
52
+ version="0.2.0",
53
  lifespan=lifespan,
54
  )
55
 
56
 
57
+ # ──────────────────────────────────────────────────────────────────────
58
+ # Π‘Π°Π·ΠΎΠ²Ρ‹Π΅ эндпоинты
59
+ # ──────────────────────────────────────────────────────────────────────
60
+
61
  @app.get("/health", response_model=HealthResponse)
62
  def health(engine: InferenceEngine = Depends(get_engine)):
63
  return HealthResponse(
64
  status="ok",
65
+ model_loaded=engine.loaded,
66
  base_model=engine.base_model_name,
67
  )
68
 
69
 
70
+ @app.post("/warmup")
71
+ async def warmup(engine: InferenceEngine = Depends(get_engine)):
72
+ """ΠŸΡ€ΠΎΠ³Ρ€Π΅Π²Π°Π΅Ρ‚ модСль ΠΎΠ΄Π½ΠΎΠΉ ΠΊΠΎΡ€ΠΎΡ‚ΠΊΠΎΠΉ Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΠ΅ΠΉ.
73
+
74
+ ΠŸΠ΅Ρ€Π²Ρ‹ΠΉ инфСрСнс Ρ…ΠΎΠ»ΠΎΠ΄Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ Π½Π° CPU сильно дольшС ΠΏΠΎΡΠ»Π΅Π΄ΡƒΡŽΡ‰ΠΈΡ…:
75
+ ΠΏΠΎΠ΄Π³Ρ€ΡƒΠΆΠ°ΡŽΡ‚ΡΡ LoRA-слои, формируСтся Π³Ρ€Π°Ρ„ вычислСний, заполняСтся
76
+ KV-кСш. Π’Ρ‹Π·ΠΎΠ² /warmup Π΄Π΅Π»Π°Π΅Ρ‚ ΠΎΠ΄ΠΈΠ½ малСнький ΠΏΡ€ΠΎΡ…ΠΎΠ΄ с ΠΌΠΈΠ½ΠΈΠΌΠ°Π»ΡŒΠ½Ρ‹ΠΌ
77
+ max_new_tokens, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Π±ΠΎΠ΅Π²ΠΎΠΉ /query ΡˆΡ‘Π» ΡƒΠΆΠ΅ ΠΏΠΎ ΠΏΡ€ΠΎΠ³Ρ€Π΅Ρ‚ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ.
78
+ """
79
+ t0 = time.time()
80
+ schema = "CREATE TABLE t (id INT);"
81
+ await run_in_threadpool(engine.generate, schema, "SELECT id", None, 16)
82
+ return {"warmup_seconds": round(time.time() - t0, 2)}
83
+
84
+
85
  @app.get("/databases", response_model=list[DatabaseInfo])
86
  def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
87
  out: list[DatabaseInfo] = []
 
94
  return out
95
 
96
 
97
+ # ──────────────────────────────────────────────────────────────────────
98
+ # PAUQ-сцСнарий (старый эндпоинт, оставлСн для совмСстимости)
99
+ # ──────────────────────────────────────────────────────────────────────
100
+
101
  @app.post("/generate-sql", response_model=GenerateResponse)
102
  async def generate_sql(
103
  req: GenerateRequest,
104
  engine: InferenceEngine = Depends(get_engine),
105
  retriever: SchemaRetriever = Depends(get_schema_retriever),
106
  ):
107
+ """ГСнСрация SQL для Π±Π°Π·Ρ‹ ΠΈΠ· PAUQ-структуры (data/databases/{db_id})."""
108
  try:
109
  schema_text = retriever.render_schema(req.db_id)
110
  except FileNotFoundError as e:
111
  raise HTTPException(status_code=404, detail=str(e)) from e
112
 
113
+ vocab = (
114
+ BusinessVocabulary.from_dict(req.vocabulary.model_dump())
115
+ if req.vocabulary
116
+ else None
117
+ )
118
 
119
+ result = await run_in_threadpool(
120
+ engine.generate, schema_text, req.question, vocab
121
+ )
122
  valid = is_valid_sql(result.sql)
123
+
124
  response = GenerateResponse(
125
  sql=result.sql,
126
  raw_output=result.raw_output,
 
128
  )
129
 
130
  if req.execute and valid:
131
+ if not is_select_only(result.sql):
132
+ response.error = (
133
+ "SQL ΠΎΡ‚ΠΊΠ»ΠΎΠ½Ρ‘Π½ Π³Π²Π°Ρ€Π΄Π΅ΠΉΠ»ΠΎΠΌ: Ρ€Π°Π·Ρ€Π΅ΡˆΠ΅Π½Ρ‹ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ запросы SELECT ΠΈ WITH."
134
+ )
135
+ logger.warning("Guardrail ΠΎΡ‚ΠΊΠ»ΠΎΠ½ΠΈΠ» SQL: %r", result.sql[:120])
136
+ return response
137
  try:
138
  response.execution = await run_in_threadpool(
139
+ _execute_sql_pauq, req.db_id, result.sql, retriever
140
  )
141
  except sqlite3.Error as e:
142
  response.error = f"SQL execution error: {e}"
 
144
  return response
145
 
146
 
147
+ def _execute_sql_pauq(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionResult:
148
  db_path = retriever.db_path(db_id)
149
  conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
150
  try:
 
153
  cur.execute(sql)
154
  rows = cur.fetchmany(100)
155
  cols = [d[0] for d in cur.description] if cur.description else []
156
+ return ExecutionResult(columns=cols, rows=[list(r) for r in rows], row_count=len(rows))
 
 
 
 
157
  finally:
158
  conn.close()
159
 
160
 
161
+ # ──────────────────────────────────────────────────────────────────────
162
+ # ΠŸΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½Π°Ρ Π‘Π” ΠΏΠΎ connection string β€” Π½ΠΎΠ²Ρ‹ΠΉ сцСнарий для Streamlit
163
+ # ──────────────────────────────────────────────────────────────────────
164
+
165
+ @app.post("/schema", response_model=SchemaResponse)
166
+ async def get_schema(req: SchemaRequest):
167
+ """Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ схСму ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” для отобраТСния Π² ΠΊΠ»ΠΈΠ΅Π½Ρ‚Π΅."""
168
+ try:
169
+ provider = ConnectionSchemaProvider(req.connection_string)
170
+ tables = await run_in_threadpool(provider.get_tables, 2 if req.include_samples else 0)
171
+ except Exception as e: # noqa: BLE001
172
+ raise HTTPException(status_code=400, detail=f"Ошибка чтСния схСмы: {e}") from e
173
+
174
+ payload = [
175
+ TablePayload(
176
+ name=t.name,
177
+ columns=[
178
+ ColumnPayload(
179
+ name=c.name, type=c.type,
180
+ nullable=c.nullable, primary_key=c.primary_key,
181
+ )
182
+ for c in t.columns
183
+ ],
184
+ sample_rows=[list(r) for r in t.sample_rows],
185
+ ddl=t.to_ddl(),
186
+ )
187
+ for t in tables
188
+ ]
189
+ return SchemaResponse(tables=payload)
190
+
191
+
192
+ @app.post("/query", response_model=QueryResponse)
193
+ async def query(
194
+ req: QueryRequest,
195
+ engine: InferenceEngine = Depends(get_engine),
196
+ ):
197
+ """ΠŸΠΎΠ»Π½Ρ‹ΠΉ pipeline: вопрос β†’ SQL β†’ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎΠ΅ исполнСниС Π½Π° Π‘Π”.
198
+
199
+ Π’ ΠΎΡ‚Π»ΠΈΡ‡ΠΈΠ΅ ΠΎΡ‚ /generate-sql, Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ с ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” ΠΏΠΎ connection
200
+ string. Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ Streamlit-ΠΊΠ»ΠΈΠ΅Π½Ρ‚ΠΎΠΌ ΠΈ сторонними интСграциями.
201
+ ΠŸΠ΅Ρ€Π΅Π΄ Π²Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ SQL ΠΏΡ€ΠΎΡ…ΠΎΠ΄ΠΈΡ‚ ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΡƒ is_select_only (Ρ€Π°Π·Π΄Π΅Π» 4.4).
202
+ """
203
+ # 1. Π‘Ρ…Π΅ΠΌΠ° Ρ†Π΅Π»Π΅Π²ΠΎΠΉ Π‘Π”
204
+ try:
205
+ provider = ConnectionSchemaProvider(req.connection_string)
206
+ schema_text = await run_in_threadpool(provider.render_schema, True)
207
+ except Exception as e: # noqa: BLE001
208
+ raise HTTPException(status_code=400, detail=f"Ошибка ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ ΠΊ Π‘Π”: {e}") from e
209
+
210
+ vocab = (
211
+ BusinessVocabulary.from_dict(req.vocabulary.model_dump())
212
+ if req.vocabulary
213
+ else None
214
+ )
215
+
216
+ # 2. Π˜Π½Ρ„Π΅Ρ€Π΅Π½Ρ
217
+ t0 = time.time()
218
+ result = await run_in_threadpool(engine.generate, schema_text, req.question, vocab)
219
+ gen_time = time.time() - t0
220
+
221
+ valid = is_valid_sql(result.sql)
222
+ response = QueryResponse(
223
+ sql=result.sql,
224
+ raw_output=result.raw_output,
225
+ is_valid_sql=valid,
226
+ gen_time_seconds=round(gen_time, 2),
227
+ )
228
+
229
+ # 3. ΠžΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎΠ΅ исполнСниС
230
+ if req.execute and valid:
231
+ if not is_select_only(result.sql):
232
+ response.error = (
233
+ "SQL ΠΎΡ‚ΠΊΠ»ΠΎΠ½Ρ‘Π½ Π³Π²Π°Ρ€Π΄Π΅ΠΉΠ»ΠΎΠΌ: Ρ€Π°Π·Ρ€Π΅ΡˆΠ΅Π½Ρ‹ ��олько запросы SELECT ΠΈ WITH."
234
+ )
235
+ logger.warning("Guardrail ΠΎΡ‚ΠΊΠ»ΠΎΠ½ΠΈΠ» SQL: %r", result.sql[:120])
236
+ return response
237
+ try:
238
+ executor = SqlExecutor(req.connection_string)
239
+ qr = await run_in_threadpool(executor.run, result.sql)
240
+ if qr.success:
241
+ response.execution = ExecutionResult(
242
+ columns=qr.columns, rows=qr.rows, row_count=qr.row_count,
243
+ )
244
+ else:
245
+ response.error = f"SQL execution error: {qr.error}"
246
+ except Exception as e: # noqa: BLE001
247
+ response.error = f"SQL execution error: {e}"
248
+
249
+ return response
250
+
251
+
252
  if __name__ == "__main__":
253
  import uvicorn
 
254
  uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
src/api/schemas.py CHANGED
@@ -5,10 +5,30 @@ from __future__ import annotations
5
  from pydantic import BaseModel, Field
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class GenerateRequest(BaseModel):
9
  question: str = Field(..., min_length=1, max_length=2000, description="Вопрос Π½Π° русском")
10
  db_id: str = Field(..., min_length=1, description="Π˜Π΄Π΅Π½Ρ‚ΠΈΡ„ΠΈΠΊΠ°Ρ‚ΠΎΡ€ Π‘Π” ΠΈΠ· PAUQ")
11
  execute: bool = Field(default=False, description="ΠŸΡ€ΠΎΠ³Π½Π°Ρ‚ΡŒ сгСнСрированный SQL Π½Π° Π‘Π”")
 
 
 
 
12
 
13
 
14
  class ExecutionResult(BaseModel):
@@ -25,6 +45,66 @@ class GenerateResponse(BaseModel):
25
  error: str | None = None
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  class DatabaseInfo(BaseModel):
29
  db_id: str
30
  tables: list[str]
 
5
  from pydantic import BaseModel, Field
6
 
7
 
8
+ class VocabularyPayload(BaseModel):
9
+ """БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅, согласованном с BusinessVocabulary.from_dict.
10
+
11
+ ΠŸΠ΅Ρ€Π΅Π΄Π°Ρ‘Ρ‚ΡΡ ΠΊΠ»ΠΈΠ΅Π½Ρ‚ΠΎΠΌ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ Π² запросС Π½Π° Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΡŽ SQL. Если ΠΏΠΎΠ»Π΅
12
+ отсутствуСт β€” модСль Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ Π±Π΅Π· ΡΠΏΠ΅Ρ†ΠΈΠ°Π»ΡŒΠ½Ρ‹Ρ… бизнСс-ΠΎΠΏΡ€Π΅Π΄Π΅Π»Π΅Π½ΠΈΠΉ.
13
+ """
14
+ company: str = ""
15
+ terms: dict[str, str] = Field(default_factory=dict)
16
+ filters: dict[str, str] = Field(default_factory=dict)
17
+ notes: list[str] = Field(default_factory=list)
18
+
19
+
20
+ # ──────────────────────────────────────────────────────────────────────
21
+ # /generate-sql β€” старый эндпоинт для PAUQ-структуры (databases_dir + db_id)
22
+ # ──────────────────────────────────────────────────────────────────────
23
+
24
  class GenerateRequest(BaseModel):
25
  question: str = Field(..., min_length=1, max_length=2000, description="Вопрос Π½Π° русском")
26
  db_id: str = Field(..., min_length=1, description="Π˜Π΄Π΅Π½Ρ‚ΠΈΡ„ΠΈΠΊΠ°Ρ‚ΠΎΡ€ Π‘Π” ΠΈΠ· PAUQ")
27
  execute: bool = Field(default=False, description="ΠŸΡ€ΠΎΠ³Π½Π°Ρ‚ΡŒ сгСнСрированный SQL Π½Π° Π‘Π”")
28
+ vocabulary: VocabularyPayload | None = Field(
29
+ default=None,
30
+ description="ΠžΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½Ρ‹ΠΉ бизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ (см. Ρ€Π°Π·Π΄Π΅Π» 3.6 Π’ΠšΠ )",
31
+ )
32
 
33
 
34
  class ExecutionResult(BaseModel):
 
45
  error: str | None = None
46
 
47
 
48
+ # ──────────────────────────────────────────────────────────────────────
49
+ # /query β€” Π½ΠΎΠ²Ρ‹ΠΉ эндпоинт для ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” (connection string)
50
+ # ──────────────────────────────────────────────────────────────────────
51
+
52
+ class QueryRequest(BaseModel):
53
+ """ΠŸΠΎΠ»Π½Ρ‹ΠΉ запрос «вопрос Π½Π° русском β†’ SQL β†’ Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚Β» для ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π”.
54
+
55
+ Π’ ΠΎΡ‚Π»ΠΈΡ‡ΠΈΠ΅ ΠΎΡ‚ GenerateRequest, Π½Π΅ привязан ΠΊ PAUQ-структурС: ΠΊΠ»ΠΈΠ΅Π½Ρ‚ сам
56
+ ΠΏΠ΅Ρ€Π΅Π΄Π°Ρ‘Ρ‚ connection string (SQLite/PostgreSQL/MySQL). Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ
57
+ Streamlit-интСрфСйсом ΠΈ Π»ΡŽΠ±Ρ‹ΠΌΠΈ сторонними ΠΊΠ»ΠΈΠ΅Π½Ρ‚Π°ΠΌΠΈ.
58
+ """
59
+ question: str = Field(..., min_length=1, max_length=2000)
60
+ connection_string: str = Field(
61
+ ..., min_length=1,
62
+ description="Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ. ΠŸΡ€ΠΈΠΌΠ΅Ρ€: sqlite:///data/demo/sales.sqlite",
63
+ )
64
+ execute: bool = Field(default=True)
65
+ vocabulary: VocabularyPayload | None = None
66
+
67
+
68
+ class QueryResponse(BaseModel):
69
+ sql: str
70
+ raw_output: str
71
+ is_valid_sql: bool
72
+ gen_time_seconds: float
73
+ execution: ExecutionResult | None = None
74
+ error: str | None = None
75
+
76
+
77
+ # ──────────────────────────────────────────────────────────────────────
78
+ # /schema β€” ΠΎΡ‚Π΄Π°Ρ‚ΡŒ схСму ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” для подстановки Π² UI
79
+ # ──────────────────────────────────────────────────────────────────────
80
+
81
+ class SchemaRequest(BaseModel):
82
+ connection_string: str = Field(..., min_length=1)
83
+ include_samples: bool = Field(default=True)
84
+
85
+
86
+ class ColumnPayload(BaseModel):
87
+ name: str
88
+ type: str
89
+ nullable: bool
90
+ primary_key: bool
91
+
92
+
93
+ class TablePayload(BaseModel):
94
+ name: str
95
+ columns: list[ColumnPayload]
96
+ sample_rows: list[list]
97
+ ddl: str
98
+
99
+
100
+ class SchemaResponse(BaseModel):
101
+ tables: list[TablePayload]
102
+
103
+
104
+ # ──────────────────────────────────────────────────────────────────────
105
+ # ΠŸΡ€ΠΎΡ‡Π΅Π΅
106
+ # ──────────────────────────────────────────────────────────────────────
107
+
108
  class DatabaseInfo(BaseModel):
109
  db_id: str
110
  tables: list[str]
src/data/prompt.py CHANGED
@@ -1,8 +1,26 @@
1
- """PromptBuilder β€” Ρ„ΠΎΡ€ΠΌΠΈΡ€ΡƒΠ΅Ρ‚ input для ΠΌΠΎΠ΄Π΅Π»ΠΈ Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅ chat-template."""
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
- SYSTEM_PROMPT = (
 
 
 
 
 
 
6
  "Π’Ρ‹ β€” ассистСнт, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΏΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΡƒΠ΅Ρ‚ вопросы Π½Π° русском языкС Π² ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹Π΅ SQL-запросы. "
7
  "Π’Π΅Π±Π΅ даётся схСма Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ… Π² Π²ΠΈΠ΄Π΅ CREATE TABLE statements ΠΈ ΠΏΡ€ΠΈΠΌΠ΅Ρ€ Π½Π΅ΡΠΊΠΎΠ»ΡŒΠΊΠΈΡ… строк. "
8
  "Π‘Π³Π΅Π½Π΅Ρ€ΠΈΡ€ΡƒΠΉ ΠΎΠ΄ΠΈΠ½ SQL-запрос, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΎΡ‚Π²Π΅Ρ‡Π°Π΅Ρ‚ Π½Π° вопрос ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ. "
@@ -11,19 +29,63 @@ SYSTEM_PROMPT = (
11
 
12
 
13
  def build_user_message(schema: str, question: str) -> str:
 
14
  return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
15
 
16
 
17
- def build_chat_messages(schema: str, question: str) -> list[dict]:
18
- """Π€ΠΎΡ€ΠΌΠ°Ρ‚ для tokenizer.apply_chat_template."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  return [
20
- {"role": "system", "content": SYSTEM_PROMPT},
21
  {"role": "user", "content": build_user_message(schema, question)},
22
  ]
23
 
24
 
25
- def build_training_example(schema: str, question: str, sql: str) -> list[dict]:
26
- """ΠŸΠΎΠ»Π½Ρ‹ΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ для SFT с ΠΎΡ‚Π²Π΅Ρ‚ΠΎΠΌ ассистСнта."""
27
- msgs = build_chat_messages(schema, question)
 
 
 
 
 
 
 
28
  msgs.append({"role": "assistant", "content": sql.strip()})
29
  return msgs
 
 
 
 
 
 
1
+ """PromptBuilder β€” Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ chat-template input для ΠΌΠΎΠ΄Π΅Π»ΠΈ.
2
+
3
+ БоотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Ρƒ 2.4 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки. Один ΠΈ Ρ‚ΠΎΡ‚ ΠΆΠ΅ Π±ΠΈΠ»Π΄Π΅Ρ€
4
+ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ ΠΏΡ€ΠΈ ΠΎΠ±ΡƒΡ‡Π΅Π½ΠΈΠΈ (Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ SFT-ΠΏΡ€ΠΈΠΌΠ΅Ρ€ΠΎΠ²) ΠΈ ΠΏΡ€ΠΈ инфСрСнсС
5
+ (Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ запроса ΠΊ Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ), Ρ‡Ρ‚ΠΎ Π³Π°Ρ€Π°Π½Ρ‚ΠΈΡ€ΡƒΠ΅Ρ‚ совпадСниС
6
+ Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π° train- ΠΈ inference-time ΠΏΡ€ΠΎΠΌΠΏΡ‚ΠΎΠ².
7
+
8
+ Помимо схСмы ΠΈ вопроса, Π±ΠΈΠ»Π΄Π΅Ρ€ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎ ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ BusinessVocabulary
9
+ (Ρ€Π°Π·Π΄Π΅Π» 3.6 Π’ΠšΠ ). БизнСс-Ρ‚Π΅Ρ€ΠΌΠΈΠ½Ρ‹ ΠΏΠΎΠ΄ΠΌΠ΅ΡˆΠΈΠ²Π°ΡŽΡ‚ΡΡ Π² систСмноС сообщСниС,
10
+ Π° Π½Π΅ ΠΊΠΎΠ½ΠΊΠ°Ρ‚Π΅Π½ΠΈΡ€ΡƒΡŽΡ‚ΡΡ ΠΊ ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»ΡŒΡΠΊΠΎΠΌΡƒ вопросу β€” это согласуСтся с
11
+ Ρ‚Π΅ΠΌ, ΠΊΠ°ΠΊ соврСмСнныС instruction-tuned ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈΠ½Ρ‚Π΅Ρ€ΠΏΡ€Π΅Ρ‚ΠΈΡ€ΡƒΡŽΡ‚ Ρ€ΠΎΠ»ΠΈ
12
+ сообщСний Π² chat-template.
13
+ """
14
 
15
  from __future__ import annotations
16
 
17
+ from typing import TYPE_CHECKING
18
+
19
+ if TYPE_CHECKING:
20
+ from src.business.vocabulary import BusinessVocabulary
21
+
22
+
23
+ BASE_SYSTEM_PROMPT = (
24
  "Π’Ρ‹ β€” ассистСнт, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΏΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΡƒΠ΅Ρ‚ вопросы Π½Π° русском языкС Π² ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹Π΅ SQL-запросы. "
25
  "Π’Π΅Π±Π΅ даётся схСма Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ… Π² Π²ΠΈΠ΄Π΅ CREATE TABLE statements ΠΈ ΠΏΡ€ΠΈΠΌΠ΅Ρ€ Π½Π΅ΡΠΊΠΎΠ»ΡŒΠΊΠΈΡ… строк. "
26
  "Π‘Π³Π΅Π½Π΅Ρ€ΠΈΡ€ΡƒΠΉ ΠΎΠ΄ΠΈΠ½ SQL-запрос, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΎΡ‚Π²Π΅Ρ‡Π°Π΅Ρ‚ Π½Π° вопрос ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ. "
 
29
 
30
 
31
  def build_user_message(schema: str, question: str) -> str:
32
+ """ΠŸΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»ΡŒΡΠΊΠ°Ρ Ρ‡Π°ΡΡ‚ΡŒ ΠΏΡ€ΠΎΠΌΠΏΡ‚Π° Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅ ``### Schema / ### Question / ### SQL:``."""
33
  return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
34
 
35
 
36
+ def build_system_message(vocabulary: "BusinessVocabulary | None" = None) -> str:
37
+ """Π‘ΠΎΠ±ΠΈΡ€Π°Π΅Ρ‚ систСмноС сообщСниС.
38
+
39
+ Если ΠΏΠ΅Ρ€Π΅Π΄Π°Π½ нСпустой бизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ, ΠΊ Π±Π°Π·ΠΎΠ²ΠΎΠΌΡƒ ΠΏΡ€ΠΎΠΌΠΏΡ‚Ρƒ добавляСтся
40
+ Π±Π»ΠΎΠΊ с опрСдСлСниями Ρ‚Π΅Ρ€ΠΌΠΈΠ½ΠΎΠ², Ρ„ΠΈΠ»ΡŒΡ‚Ρ€Π°ΠΌΠΈ ΠΈ ΠΏΡ€Π°Π²ΠΈΠ»Π°ΠΌΠΈ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ. Π­Ρ‚ΠΎ
41
+ позволяСт Π°Π΄Π°ΠΏΡ‚ΠΈΡ€ΠΎΠ²Π°Ρ‚ΡŒ систСму ΠΊ Ρ‚Π΅Ρ€ΠΌΠΈΠ½ΠΎΠ»ΠΎΠ³ΠΈΠΈ ΠΊΠΎΠ½ΠΊΡ€Π΅Ρ‚Π½ΠΎΠΉ ΠΎΡ€Π³Π°Π½ΠΈΠ·Π°Ρ†ΠΈΠΈ
42
+ Π±Π΅Π· ΠΏΠΎΠ²Ρ‚ΠΎΡ€Π½ΠΎΠ³ΠΎ дообучСния ΠΌΠΎΠ΄Π΅Π»ΠΈ.
43
+ """
44
+ if vocabulary is None or not vocabulary:
45
+ return BASE_SYSTEM_PROMPT
46
+ context = vocabulary.render_system_context()
47
+ if not context:
48
+ return BASE_SYSTEM_PROMPT
49
+ return BASE_SYSTEM_PROMPT + "\n\n" + context
50
+
51
+
52
+ def build_chat_messages(
53
+ schema: str,
54
+ question: str,
55
+ vocabulary: "BusinessVocabulary | None" = None,
56
+ ) -> list[dict]:
57
+ """БообщСния для ``tokenizer.apply_chat_template``.
58
+
59
+ ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹
60
+ ---------
61
+ schema : str
62
+ ВСкстовоС прСдставлСниС схСмы (CREATE TABLE + sample rows).
63
+ question : str
64
+ Вопрос ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ Π½Π° русском языкС.
65
+ vocabulary : BusinessVocabulary, optional
66
+ БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ. Если ΠΏΠ΅Ρ€Π΅Π΄Π°Π½ β€” добавляСтся Π² систСмноС
67
+ сообщСниС, Π½Π΅ Π½Π°Ρ€ΡƒΡˆΠ°Ρ структуры ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»ΡŒΡΠΊΠΎΠΉ Ρ€Π΅ΠΏΠ»ΠΈΠΊΠΈ.
68
+ """
69
  return [
70
+ {"role": "system", "content": build_system_message(vocabulary)},
71
  {"role": "user", "content": build_user_message(schema, question)},
72
  ]
73
 
74
 
75
+ def build_training_example(
76
+ schema: str,
77
+ question: str,
78
+ sql: str,
79
+ vocabulary: "BusinessVocabulary | None" = None,
80
+ ) -> list[dict]:
81
+ """ΠŸΠΎΠ»Π½Ρ‹ΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ с эталонной Ρ€Π΅ΠΏΠ»ΠΈΠΊΠΎΠΉ ассистСнта для Supervised
82
+ Fine-Tuning (Ρ€Π°Π·Π΄Π΅Π» 2.4 Π’ΠšΠ ).
83
+ """
84
+ msgs = build_chat_messages(schema, question, vocabulary)
85
  msgs.append({"role": "assistant", "content": sql.strip()})
86
  return msgs
87
+
88
+
89
+ # ΠžΠ±Ρ€Π°Ρ‚Π½Π°Ρ ΡΠΎΠ²ΠΌΠ΅ΡΡ‚ΠΈΠΌΠΎΡΡ‚ΡŒ. Имя SYSTEM_PROMPT использовалось Π² старом ΠΊΠΎΠ΄Π΅
90
+ # ΠΈ Π² тСстах; сохраняСм алиас, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Π½Π΅ Π»ΠΎΠΌΠ°Ρ‚ΡŒ ΠΈΠΌΠΏΠΎΡ€Ρ‚Ρ‹.
91
+ SYSTEM_PROMPT = BASE_SYSTEM_PROMPT
src/data/schema.py CHANGED
@@ -1,76 +1,27 @@
1
- """SchemaRetriever β€” ΠΈΠ·Π²Π»Π΅ΠΊΠ°Π΅Ρ‚ DDL ΠΈ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк ΠΈΠ· SQLite-Ρ„Π°ΠΉΠ»ΠΎΠ² PAUQ/Spider."""
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
- import sqlite3
6
- from dataclasses import dataclass
7
  from pathlib import Path
8
 
 
9
 
10
- @dataclass
11
- class TableInfo:
12
- name: str
13
- create_sql: str
14
- sample_rows: list[tuple]
15
 
16
 
17
- class SchemaRetriever:
18
- """Π§ΠΈΡ‚Π°Π΅Ρ‚ структуру SQLite-Π‘Π” для ΠΏΠΎΠ΄Π°Ρ‡ΠΈ Π² prompt ΠΌΠΎΠ΄Π΅Π»ΠΈ."""
19
 
20
  def __init__(self, databases_dir: Path | str):
21
- self.databases_dir = Path(databases_dir)
22
-
23
- def db_path(self, db_id: str) -> Path:
24
- """Π’ Spider/PAUQ каТдая Π‘Π” Π»Π΅ΠΆΠΈΡ‚ Π² databases_dir/{db_id}/{db_id}.sqlite."""
25
- path = self.databases_dir / db_id / f"{db_id}.sqlite"
26
- if not path.exists():
27
- raise FileNotFoundError(f"Database file not found: {path}")
28
- return path
29
-
30
- def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
31
- """Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ список Ρ‚Π°Π±Π»ΠΈΡ† с CREATE-SQL ΠΈ ΠΏΡ€ΠΈΠΌΠ΅Ρ€ΠΎΠΌ строк."""
32
- path = self.db_path(db_id)
33
- conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
34
- try:
35
- conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
36
- cur = conn.cursor()
37
- cur.execute(
38
- "SELECT name, sql FROM sqlite_master "
39
- "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
40
- )
41
- rows = cur.fetchall()
42
-
43
- tables: list[TableInfo] = []
44
- for table_name, create_sql in rows:
45
- if not create_sql:
46
- continue
47
- try:
48
- cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
49
- samples = cur.fetchall()
50
- except sqlite3.Error:
51
- samples = []
52
- tables.append(
53
- TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
54
- )
55
- return tables
56
- finally:
57
- conn.close()
58
-
59
- def render_schema(self, db_id: str, include_samples: bool = True) -> str:
60
- """ВСкстовоС прСдставлСниС схСмы для prompt'Π°."""
61
- tables = self.get_tables(db_id)
62
- parts: list[str] = []
63
- for t in tables:
64
- parts.append(t.create_sql + ";")
65
- if include_samples and t.sample_rows:
66
- parts.append(f"-- ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк ΠΈΠ· {t.name}:")
67
- for row in t.sample_rows:
68
- parts.append(f"-- {row}")
69
- parts.append("")
70
- return "\n".join(parts).strip()
71
-
72
- def list_databases(self) -> list[str]:
73
- """Бписок доступных db_id."""
74
- if not self.databases_dir.exists():
75
- return []
76
- return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
 
1
+ """SchemaRetriever β€” фасад Π½Π°Π΄ :class:`SpiderSchemaProvider`.
2
+
3
+ Π˜ΡΡ‚ΠΎΡ€ΠΈΡ‡Π΅ΡΠΊΠΈ здСсь ΠΆΠΈΠ»Π° полная рСализация чтСния PAUQ/Spider-схСм. ПослС
4
+ Ρ€Π΅Ρ„Π°ΠΊΡ‚ΠΎΡ€ΠΈΠ½Π³Π° (Ρ€Π°Π·Π΄Π΅Π» 4.2 Π°ΡƒΠ΄ΠΈΡ‚Π°) общая Π»ΠΎΠ³ΠΈΠΊΠ° ΠΏΠ΅Ρ€Π΅Π΅Ρ…Π°Π»Π° Π²
5
+ ``schema_provider.py``, Π° ``SchemaRetriever`` оставлСн ΠΊΠ°ΠΊ тонкая
6
+ ΠΎΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° Ρ€Π°Π΄ΠΈ ΠΎΠ±Ρ€Π°Ρ‚Π½ΠΎΠΉ совмСстимости ΠΈΠΌΠΏΠΎΡ€Ρ‚ΠΎΠ² (``from src.data.schema
7
+ import SchemaRetriever``), ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Π΅ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΡŽΡ‚ΡΡ Π² API ΠΈ тСстах.
8
+
9
+ Новый ΠΊΠΎΠ΄ стоит ΠΏΠΈΡΠ°Ρ‚ΡŒ сразу Ρ‡Π΅Ρ€Π΅Π· :class:`SpiderSchemaProvider`.
10
+ """
11
 
12
  from __future__ import annotations
13
 
 
 
14
  from pathlib import Path
15
 
16
+ from src.data.schema_provider import SpiderSchemaProvider, TableSchema
17
 
18
+ # Алиас для совмСстимости со старыми ΠΈΠΌΠΏΠΎΡ€Ρ‚Π°ΠΌΠΈ Π²ΠΈΠ΄Π°
19
+ # ``from src.data.schema import TableInfo``.
20
+ TableInfo = TableSchema
 
 
21
 
22
 
23
+ class SchemaRetriever(SpiderSchemaProvider):
24
+ """БовмСстимый алиас :class:`SpiderSchemaProvider`."""
25
 
26
  def __init__(self, databases_dir: Path | str):
27
+ super().__init__(databases_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data/schema_provider.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Π•Π΄ΠΈΠ½Ρ‹ΠΉ интСрфСйс Ρ€Π°Π±ΠΎΡ‚Ρ‹ со схСмами Π±Π°Π· Π΄Π°Π½Π½Ρ‹Ρ….
2
+
3
+ Π”ΠΎ Ρ€Π΅Ρ„Π°ΠΊΡ‚ΠΎΡ€ΠΈΠ½Π³Π° Π² ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π΅ сущСствовали Π΄Π²Π΅ нСзависимыС ΠΈΠ΅Ρ€Π°Ρ€Ρ…ΠΈΠΈ:
4
+
5
+ * ``SchemaRetriever`` (``src/data/schema.py``) β€” Ρ‡ΠΈΡ‚Π°Π» DDL ΠΈΠ· SQLite-Ρ„Π°ΠΉΠ»ΠΎΠ²
6
+ Π² Spider/PAUQ-структурС ``{databases_dir}/{db_id}/{db_id}.sqlite``.
7
+ * ``DbConnector`` (``src/db/connector.py``) β€” ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π°Π»ΡΡ ΠΊ ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π”
8
+ ΠΏΠΎ строкС ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ, ΡƒΠΌΠ΅Π» SQLite/PostgreSQL/MySQL.
9
+
10
+ Они Ρ€Π΅ΡˆΠ°Π»ΠΈ ΠΎΠ΄Π½Ρƒ Π·Π°Π΄Π°Ρ‡Ρƒ, Π½ΠΎ ΠΏΠΎ-Ρ€Π°Π·Π½ΠΎΠΌΡƒ оформляли Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚
11
+ (``TableInfo`` Π² ΠΊΠ°ΠΆΠ΄ΠΎΠΌ Π±Ρ‹Π» свой) ΠΈ Π½Π΅ ΠΈΠΌΠ΅Π»ΠΈ ΠΎΠ±Ρ‰Π΅Π³ΠΎ интСрфСйса. Π­Ρ‚ΠΎΡ‚
12
+ ΠΌΠΎΠ΄ΡƒΠ»ΡŒ Π²Π²ΠΎΠ΄ΠΈΡ‚ Π΅Π΄ΠΈΠ½Ρ‹ΠΉ ΠΏΡ€ΠΎΡ‚ΠΎΠΊΠΎΠ» ``SchemaProvider`` ΠΈ ΠΎΠ±Ρ‰ΠΈΠΉ dataclass
13
+ ``TableSchema``. Π‘Ρ‚Π°Ρ€Ρ‹Π΅ классы становятся Ρ‚ΠΎΠ½ΠΊΠΈΠΌΠΈ фасадами ΠΏΠΎΠ²Π΅Ρ€Ρ…
14
+ Π½ΠΎΠ²Ρ‹Ρ… Ρ€Π΅Π°Π»ΠΈΠ·Π°Ρ†ΠΈΠΉ.
15
+
16
+ БоотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Π°ΠΌ 3.4 ΠΈ 4.1 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from dataclasses import dataclass, field
22
+ from pathlib import Path
23
+ from typing import Iterable, Protocol
24
+
25
+
26
+ @dataclass
27
+ class ColumnSchema:
28
+ """ОписаниС ΠΊΠΎΠ»ΠΎΠ½ΠΊΠΈ Ρ‚Π°Π±Π»ΠΈΡ†Ρ‹."""
29
+
30
+ name: str
31
+ type: str
32
+ nullable: bool = True
33
+ primary_key: bool = False
34
+
35
+
36
+ @dataclass
37
+ class TableSchema:
38
+ """Π£Π½ΠΈΡ„ΠΈΡ†ΠΈΡ€ΠΎΠ²Π°Π½Π½ΠΎΠ΅ описаниС Ρ‚Π°Π±Π»ΠΈΡ†Ρ‹ нСзависимо ΠΎΡ‚ источника схСмы.
39
+
40
+ ПолС ``create_sql`` Ρ…Ρ€Π°Π½ΠΈΡ‚ исходный CREATE TABLE statement, Ссли ΠΎΠ½
41
+ доступСн (Π°ΠΊΡ‚ΡƒΠ°Π»ΡŒΠ½ΠΎ для SQLite β€” ΠΎΠ½ Π΅Π³ΠΎ сам ΠΎΡ‚Π΄Π°Ρ‘Ρ‚ ΠΈΠ· ``sqlite_master``).
42
+ Когда источник схСмы β€” PostgreSQL/MySQL, DDL гСнСрируСтся ΠΈΠ·
43
+ ΠΌΠ΅Ρ‚Π°Π΄Π°Π½Π½Ρ‹Ρ… Ρ‡Π΅Ρ€Π΅Π· :meth:`to_ddl`.
44
+ """
45
+
46
+ name: str
47
+ columns: list[ColumnSchema] = field(default_factory=list)
48
+ sample_rows: list[tuple] = field(default_factory=list)
49
+ create_sql: str | None = None
50
+
51
+ def to_ddl(self) -> str:
52
+ """CREATE TABLE для подстановки Π² ΠΏΡ€ΠΎΠΌΠΏΡ‚.
53
+
54
+ Если Π΅ΡΡ‚ΡŒ ΠΎΡ€ΠΈΠ³ΠΈΠ½Π°Π»ΡŒΠ½Ρ‹ΠΉ ``create_sql`` β€” Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅ΠΌ Π΅Π³ΠΎ, Ρ‡Ρ‚ΠΎΠ±Ρ‹
55
+ ΡΠΎΡ…Ρ€Π°Π½ΠΈΡ‚ΡŒ всС Π½ΡŽΠ°Π½ΡΡ‹ (ограничСния, FK, AUTOINCREMENT). Π˜Π½Π°Ρ‡Π΅
56
+ собираСм ΠΈΠ· ΠΌΠ΅Ρ‚Π°Π΄Π°Π½Π½Ρ‹Ρ… ΠΊΠΎΠ»ΠΎΠ½ΠΎΠΊ.
57
+ """
58
+ if self.create_sql:
59
+ return self.create_sql.rstrip(";") + ";"
60
+ col_parts: list[str] = []
61
+ for col in self.columns:
62
+ line = f" {col.name} {col.type}"
63
+ if col.primary_key:
64
+ line += " PRIMARY KEY"
65
+ if not col.nullable:
66
+ line += " NOT NULL"
67
+ col_parts.append(line)
68
+ return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
69
+
70
+
71
+ class SchemaProvider(Protocol):
72
+ """ΠŸΡ€ΠΎΡ‚ΠΎΠΊΠΎΠ» любого источника схСмы Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ….
73
+
74
+ ΠšΠΎΠ½Ρ‚Ρ€Π°ΠΊΡ‚ ΠΌΠΈΠ½ΠΈΠΌΠ°Π»ΡŒΠ½Ρ‹ΠΉ: ΡƒΠΌΠ΅Ρ‚ΡŒ ΠΏΠ΅Ρ€Π΅Ρ‡ΠΈΡΠ»ΠΈΡ‚ΡŒ Ρ‚Π°Π±Π»ΠΈΡ†Ρ‹ ΠΈ ΠΎΡ‚Ρ€Π΅Π½Π΄Π΅Ρ€ΠΈΡ‚ΡŒ схСму
75
+ Π² тСкст для подстановки Π² ΠΏΡ€ΠΎΠΌΠΏΡ‚. Π­Ρ‚ΠΎΠ³ΠΎ достаточно ΠΈ для PAUQ-сцСнария
76
+ (``SpiderSchemaProvider``), ΠΈ для ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ ΠΊ Π±ΠΎΠ΅Π²ΠΎΠΉ Π‘Π” ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ
77
+ (``ConnectionSchemaProvider``).
78
+ """
79
+
80
+ def list_tables(self) -> list[str]: ...
81
+ def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]: ...
82
+ def render_schema(self, include_samples: bool = True) -> str: ...
83
+
84
+
85
+ # ──────────────────────────────────────────────────────────────────────
86
+ # Π£Ρ‚ΠΈΠ»ΠΈΡ‚Π° Ρ€Π΅Π½Π΄Π΅Ρ€ΠΈΠ½Π³Π° β€” общая для всСх Ρ€Π΅Π°Π»ΠΈΠ·Π°Ρ†ΠΈΠΉ
87
+ # ──────────────────────────────────────────────────────────────────────
88
+
89
+ def render_tables(tables: Iterable[TableSchema], include_samples: bool = True) -> str:
90
+ """Π‘ΠΎΠ±ΠΈΡ€Π°Π΅Ρ‚ тСкстовоС прСдставлСниС списка Ρ‚Π°Π±Π»ΠΈΡ† для ΠΏΡ€ΠΎΠΌΠΏΡ‚Π°."""
91
+ parts: list[str] = []
92
+ for t in tables:
93
+ parts.append(t.to_ddl())
94
+ if include_samples and t.sample_rows:
95
+ parts.append(f"-- ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк ΠΈΠ· {t.name}:")
96
+ for row in t.sample_rows:
97
+ parts.append(f"-- {row}")
98
+ parts.append("")
99
+ return "\n".join(parts).strip()
100
+
101
+
102
+ # ────────────────────────────────────────────────────────────────────��─
103
+ # РСализация 1 β€” Spider/PAUQ-структура
104
+ # ──────────────────────────────────────────────────────────────────────
105
+
106
+ class SpiderSchemaProvider:
107
+ """SchemaProvider для ΠΊΠ°Ρ‚Π°Π»ΠΎΠ³Π° ``data/databases/{db_id}/{db_id}.sqlite``.
108
+
109
+ Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ ΠΏΡ€ΠΈ Ρ€Π°Π±ΠΎΡ‚Π΅ с PAUQ/Spider: каТдая Π‘Π” Π»Π΅ΠΆΠΈΡ‚ Π² ΠΎΠ΄Π½ΠΎΠΈΠΌΡ‘Π½Π½ΠΎΠΉ
110
+ ΠΏΠ°ΠΏΠΊΠ΅. Один экзСмпляр SpiderSchemaProvider обслуТиваСт всю ΠΊΠΎΠ»Π»Π΅ΠΊΡ†ΠΈΡŽ
111
+ Π±Π°Π· β€” конкрСтная Π‘Π” выбираСтся ΠΏΠΎ ``db_id`` Π² ΠΌΠ΅Ρ‚ΠΎΠ΄Π°Ρ….
112
+ """
113
+
114
+ def __init__(self, databases_dir: Path | str):
115
+ self.databases_dir = Path(databases_dir)
116
+
117
+ def list_databases(self) -> list[str]:
118
+ if not self.databases_dir.exists():
119
+ return []
120
+ return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
121
+
122
+ def db_path(self, db_id: str) -> Path:
123
+ path = self.databases_dir / db_id / f"{db_id}.sqlite"
124
+ if not path.exists():
125
+ raise FileNotFoundError(f"Database file not found: {path}")
126
+ return path
127
+
128
+ def for_database(self, db_id: str) -> "ConnectionSchemaProvider":
129
+ """Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ SchemaProvider, привязанный ΠΊ ΠΊΠΎΠ½ΠΊΡ€Π΅Ρ‚Π½ΠΎΠΉ Π‘Π”."""
130
+ return ConnectionSchemaProvider(f"sqlite:///{self.db_path(db_id)}")
131
+
132
+ # ── Π‘ΠΎΠ²ΠΌΠ΅ΡΡ‚ΠΈΠΌΠΎΡΡ‚ΡŒ с ΠΏΡ€Π΅Π΄Ρ‹Π΄ΡƒΡ‰ΠΈΠΌ API SchemaRetriever ────────────────
133
+
134
+ def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableSchema]:
135
+ return self.for_database(db_id).get_tables(n_sample_rows=n_sample_rows)
136
+
137
+ def render_schema(self, db_id: str, include_samples: bool = True) -> str:
138
+ return self.for_database(db_id).render_schema(include_samples=include_samples)
139
+
140
+
141
+ # ──────────────────────────────────────────────────────────────────────
142
+ # РСализация 2 β€” ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½Π°Ρ Π‘Π” ΠΏΠΎ connection string
143
+ # ──────────────────────────────────────────────────────────────────────
144
+
145
+ class ConnectionSchemaProvider:
146
+ """SchemaProvider для ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” (SQLite/PostgreSQL/MySQL).
147
+
148
+ Π”Π΅Π»Π΅Π³ΠΈΡ€ΡƒΠ΅Ρ‚ Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ DbConnector'Ρƒ, Π½ΠΎ Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚Ρ‹ Π΅Π΄ΠΈΠ½ΠΎΠ³ΠΎ Ρ‚ΠΈΠΏΠ°
149
+ ``TableSchema``. Π­Ρ‚ΠΎ Π½ΡƒΠΆΠ½ΠΎ, Ρ‡Ρ‚ΠΎΠ±Ρ‹ ΠΎΠ΄ΠΈΠ½ ΠΈ Ρ‚ΠΎΡ‚ ΠΆΠ΅ ΠΊΠΎΠ΄ Π² API ΠΈ Streamlit
150
+ ΠΌΠΎΠ³ Ρ€Π°Π±ΠΎΡ‚Π°Ρ‚ΡŒ ΠΊΠ°ΠΊ с PAUQ-структурой, Ρ‚Π°ΠΊ ΠΈ с Π±ΠΎΠ΅Π²ΠΎΠΉ Π‘Π” ΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚Π΅Π»Ρ.
151
+ """
152
+
153
+ def __init__(self, connection_string: str, n_sample_rows: int = 2):
154
+ # Π˜ΠΌΠΏΠΎΡ€Ρ‚ здСсь, Ρ‡Ρ‚ΠΎΠ±Ρ‹ ΠΈΠ·Π±Π΅ΠΆΠ°Ρ‚ΡŒ ΠΊΠΎΠ»ΡŒΡ†Π΅Π²ΠΎΠΉ зависимости
155
+ # (db.connector β†’ data.schema_provider Π² случаС фасада).
156
+ from src.db.connector import DbConnector
157
+ self._connector = DbConnector(connection_string, n_sample_rows=n_sample_rows)
158
+ self.connection_string = self._connector.connection_string
159
+
160
+ # ── Π‘Π°Π·ΠΎΠ²Ρ‹Π΅ ΠΎΠΏΠ΅Ρ€Π°Ρ†ΠΈΠΈ SchemaProvider ───────────────────────────────
161
+
162
+ def list_tables(self) -> list[str]:
163
+ return self._connector.list_tables()
164
+
165
+ def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]:
166
+ # DbConnector Π² Ρ‚Π΅ΠΊΡƒΡ‰Π΅ΠΉ Ρ€Π΅Π°Π»ΠΈΠ·Π°Ρ†ΠΈΠΈ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ свой n_sample_rows ΠΈΠ· ctor;
167
+ # для совмСстимости с ΠΏΡ€ΠΎΡ‚ΠΎΠΊΠΎΠ»ΠΎΠΌ β€” ΠΈΠ³Π½ΠΎΡ€ΠΈΡ€ΡƒΠ΅ΠΌ ΠΏΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€ здСсь, довСряя
168
+ # настройкС ΠΊΠΎΠ½Π½Π΅ΠΊΡ‚ΠΎΡ€Π°. ΠŸΡ€ΠΈ ΠΆΠ΅Π»Π°Π½ΠΈΠΈ ΠΌΠΎΠΆΠ½ΠΎ завСсти setter.
169
+ raw = self._connector.get_schema(include_samples=n_sample_rows > 0)
170
+ return [
171
+ TableSchema(
172
+ name=t.name,
173
+ columns=[
174
+ ColumnSchema(
175
+ name=c.name, type=c.type,
176
+ nullable=c.nullable, primary_key=c.primary_key,
177
+ )
178
+ for c in t.columns
179
+ ],
180
+ sample_rows=list(t.sample_rows),
181
+ )
182
+ for t in raw
183
+ ]
184
+
185
+ def render_schema(self, include_samples: bool = True) -> str:
186
+ return render_tables(self.get_tables(n_sample_rows=2 if include_samples else 0),
187
+ include_samples=include_samples)
188
+
189
+ def test_connection(self) -> bool:
190
+ return self._connector.test_connection()
src/db/connector.py CHANGED
@@ -1,11 +1,11 @@
1
- """DbConnector -- podklyuchenie k proizvolnoy baze dannykh i chtenie skhemy.
2
 
3
- Podderzhivaemye tipy BD:
4
- SQLite -- put k faylu: "sqlite:///path/to/db.sqlite" ili prosto put
5
- PostgreSQL -- "postgresql://user:pass@host:port/dbname" (trebuet psycopg2)
6
- MySQL -- "mysql://user:pass@host:port/dbname" (trebuet pymysql)
7
 
8
- Primer:
9
  conn = DbConnector("sqlite:///data/demo/sales.sqlite")
10
  print(conn.render_schema())
11
  tables = conn.list_tables()
@@ -13,11 +13,14 @@ Primer:
13
 
14
  from __future__ import annotations
15
 
 
16
  import sqlite3
17
  from dataclasses import dataclass, field
18
  from pathlib import Path
19
  from urllib.parse import urlparse
20
 
 
 
21
 
22
  @dataclass
23
  class ColumnInfo:
@@ -34,7 +37,7 @@ class TableInfo:
34
  sample_rows: list[tuple] = field(default_factory=list)
35
 
36
  def to_ddl(self) -> str:
37
- """Generiruet CREATE TABLE statement iz metadannykh."""
38
  col_parts = []
39
  for col in self.columns:
40
  line = f" {col.name} {col.type}"
@@ -47,7 +50,7 @@ class TableInfo:
47
 
48
 
49
  class DbConnector:
50
- """Universalnyy konektor k BD. Umeet chitat skhemu dlya podstanovki v prompt."""
51
 
52
  def __init__(self, connection_string: str, n_sample_rows: int = 2):
53
  self.connection_string = self._normalize(connection_string)
@@ -66,7 +69,7 @@ class DbConnector:
66
  for t in tables:
67
  parts.append(t.to_ddl())
68
  if include_samples and t.sample_rows:
69
- parts.append(f"-- Primery strok iz {t.name}:")
70
  for row in t.sample_rows:
71
  parts.append(f"-- {row}")
72
  parts.append("")
@@ -76,7 +79,8 @@ class DbConnector:
76
  try:
77
  self._get_tables(n_sample_rows=0)
78
  return True
79
- except Exception:
 
80
  return False
81
 
82
  def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
@@ -87,11 +91,19 @@ class DbConnector:
87
  elif self._db_type == "mysql":
88
  return self._get_tables_mysql(n_sample_rows)
89
  else:
90
- raise ValueError(f"Neizvestnyy tip BD: {self._db_type}")
91
 
92
  def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
93
- path = self._safe_sqlite_path(self._sqlite_path())
94
- conn = sqlite3.connect(str(path))
 
 
 
 
 
 
 
 
95
  conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
96
  try:
97
  cur = conn.cursor()
@@ -118,8 +130,9 @@ class DbConnector:
118
  try:
119
  cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
120
  samples = cur.fetchall()
121
- except sqlite3.Error:
122
- pass
 
123
  tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
124
  return tables
125
  finally:
@@ -129,7 +142,7 @@ class DbConnector:
129
  try:
130
  import psycopg2 # type: ignore
131
  except ImportError as e:
132
- raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
133
 
134
  conn = psycopg2.connect(self.connection_string)
135
  try:
@@ -166,7 +179,7 @@ class DbConnector:
166
  try:
167
  import pymysql # type: ignore
168
  except ImportError as e:
169
- raise ImportError("Ustanovi pymysql: pip install pymysql") from e
170
 
171
  parsed = urlparse(self.connection_string)
172
  conn = pymysql.connect(
@@ -207,22 +220,22 @@ class DbConnector:
207
  return Path(cs)
208
 
209
  @staticmethod
210
- def _safe_sqlite_path(path: Path) -> Path:
211
- """Esli ryadom s BD est journal-fayl, kopΠΈΡ€ΡƒΠ΅ΠΌ fayl vo vremennuyu direktoriu."""
212
- import shutil
213
- import tempfile
214
- journal = Path(str(path) + "-journal")
215
- wal = Path(str(path) + "-wal")
216
- if journal.exists() or wal.exists():
217
- tmp = Path(tempfile.mktemp(suffix=".sqlite"))
218
- shutil.copy2(path, tmp)
219
- return tmp
220
- return path
221
 
222
  @staticmethod
223
  def _normalize(cs: str) -> str:
224
- """Esli peredan prosto put k faylu -- prevraschaem v sqlite:// URI."""
 
 
 
 
 
 
225
  cs = cs.strip()
 
 
226
  if cs.endswith(".sqlite") or cs.endswith(".db"):
227
  return f"sqlite:///{cs}"
228
  return cs
@@ -235,4 +248,4 @@ class DbConnector:
235
  return "postgresql"
236
  if cs.startswith("mysql"):
237
  return "mysql"
238
- raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
 
1
+ """DbConnector β€” ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΊ ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠΉ Π‘Π” ΠΈ Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ схСмы.
2
 
3
+ ΠŸΠΎΠ΄Π΄Π΅Ρ€ΠΆΠΈΠ²Π°Π΅ΠΌΡ‹Π΅ Ρ‚ΠΈΠΏΡ‹ Π‘Π”:
4
+ SQLite β€” ΠΏΡƒΡ‚ΡŒ ΠΊ Ρ„Π°ΠΉΠ»Ρƒ: "sqlite:///path/to/db.sqlite" ΠΈΠ»ΠΈ просто ΠΏΡƒΡ‚ΡŒ
5
+ PostgreSQL β€” "postgresql://user:pass@host:port/dbname" (Ρ‚Ρ€Π΅Π±ΡƒΠ΅Ρ‚ psycopg2)
6
+ MySQL β€” "mysql://user:pass@host:port/dbname" (Ρ‚Ρ€Π΅Π±ΡƒΠ΅Ρ‚ pymysql)
7
 
8
+ ΠŸΡ€ΠΈΠΌΠ΅Ρ€:
9
  conn = DbConnector("sqlite:///data/demo/sales.sqlite")
10
  print(conn.render_schema())
11
  tables = conn.list_tables()
 
13
 
14
  from __future__ import annotations
15
 
16
+ import logging
17
  import sqlite3
18
  from dataclasses import dataclass, field
19
  from pathlib import Path
20
  from urllib.parse import urlparse
21
 
22
+ logger = logging.getLogger(__name__)
23
+
24
 
25
  @dataclass
26
  class ColumnInfo:
 
37
  sample_rows: list[tuple] = field(default_factory=list)
38
 
39
  def to_ddl(self) -> str:
40
+ """Π“Π΅Π½Π΅Ρ€ΠΈΡ€ΡƒΠ΅Ρ‚ CREATE TABLE statement ΠΈΠ· ΠΌΠ΅Ρ‚Π°Π΄Π°Π½Π½Ρ‹Ρ…."""
41
  col_parts = []
42
  for col in self.columns:
43
  line = f" {col.name} {col.type}"
 
50
 
51
 
52
  class DbConnector:
53
+ """Π£Π½ΠΈΠ²Π΅Ρ€ΡΠ°Π»ΡŒΠ½Ρ‹ΠΉ ΠΊΠΎΠ½Π½Π΅ΠΊΡ‚ΠΎΡ€ ΠΊ Π‘Π”. Π§ΠΈΡ‚Π°Π΅Ρ‚ схСму для подстановки Π² ΠΏΡ€ΠΎΠΌΠΏΡ‚."""
54
 
55
  def __init__(self, connection_string: str, n_sample_rows: int = 2):
56
  self.connection_string = self._normalize(connection_string)
 
69
  for t in tables:
70
  parts.append(t.to_ddl())
71
  if include_samples and t.sample_rows:
72
+ parts.append(f"-- ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк ΠΈΠ· {t.name}:")
73
  for row in t.sample_rows:
74
  parts.append(f"-- {row}")
75
  parts.append("")
 
79
  try:
80
  self._get_tables(n_sample_rows=0)
81
  return True
82
+ except Exception as e: # noqa: BLE001
83
+ logger.warning("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΊ Π‘Π” Π½Π΅ ΡƒΠ΄Π°Π»ΠΎΡΡŒ: %s", e)
84
  return False
85
 
86
  def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
 
91
  elif self._db_type == "mysql":
92
  return self._get_tables_mysql(n_sample_rows)
93
  else:
94
+ raise ValueError(f"НСизвСстный Ρ‚ΠΈΠΏ Π‘Π”: {self._db_type}")
95
 
96
  def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
97
+ """SQLite-ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ Π² Ρ€Π΅ΠΆΠΈΠΌΠ΅ read-only Ρ‡Π΅Ρ€Π΅Π· URI.
98
+
99
+ immutable=1 Π³ΠΎΠ²ΠΎΡ€ΠΈΡ‚ SQLite, Ρ‡Ρ‚ΠΎ Ρ„Π°ΠΉΠ» Π½Π΅ измСняСтся Π²ΠΎ врСмя сСссии,
100
+ поэтому journal/WAL-Ρ„Π°ΠΉΠ»Ρ‹ ΠΌΠΎΠΆΠ½ΠΎ ΠΈΠ³Π½ΠΎΡ€ΠΈΡ€ΠΎΠ²Π°Ρ‚ΡŒ. Π­Ρ‚ΠΎ ΡƒΠ±ΠΈΡ€Π°Π΅Ρ‚ ΠΏΡ€Π΅ΠΆΠ½ΡŽΡŽ
101
+ Π»ΠΎΠ³ΠΈΠΊΡƒ с ΠΊΠΎΠΏΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ Π‘Π” Π²ΠΎ Π²Ρ€Π΅ΠΌΠ΅Π½Π½ΡƒΡŽ Π΄ΠΈΡ€Π΅ΠΊΡ‚ΠΎΡ€ΠΈΡŽ ΠΈ Π·Π°ΠΎΠ΄Π½ΠΎ Π΄Π°Ρ‘Ρ‚
102
+ guardrail-ΡƒΡ€ΠΎΠ²Π΅Π½ΡŒ бСзопасности: любая ΠΌΠΎΠ΄ΠΈΡ„ΠΈΡ†ΠΈΡ€ΡƒΡŽΡ‰Π°Ρ опСрация
103
+ Π½Π° Ρ‚Π°ΠΊΠΎΠΌ соСдинСнии Π·Π°Π²Π΅Ρ€ΡˆΠΈΡ‚ΡΡ ошибкой.
104
+ """
105
+ path = self._sqlite_path()
106
+ conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
107
  conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
108
  try:
109
  cur = conn.cursor()
 
130
  try:
131
  cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
132
  samples = cur.fetchall()
133
+ except sqlite3.Error as e:
134
+ logger.debug("НС ΡƒΠ΄Π°Π»ΠΎΡΡŒ ΠΏΠΎΠ»ΡƒΡ‡ΠΈΡ‚ΡŒ sample-строки для %s: %s",
135
+ name, e)
136
  tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
137
  return tables
138
  finally:
 
142
  try:
143
  import psycopg2 # type: ignore
144
  except ImportError as e:
145
+ raise ImportError("Установи psycopg2: pip install psycopg2-binary") from e
146
 
147
  conn = psycopg2.connect(self.connection_string)
148
  try:
 
179
  try:
180
  import pymysql # type: ignore
181
  except ImportError as e:
182
+ raise ImportError("Установи pymysql: pip install pymysql") from e
183
 
184
  parsed = urlparse(self.connection_string)
185
  conn = pymysql.connect(
 
220
  return Path(cs)
221
 
222
  @staticmethod
223
+ def _sqlite_uri(path: Path) -> str:
224
+ """Read-only URI для SQLite с ΠΈΠ³Π½ΠΎΡ€ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ journal/WAL."""
225
+ return f"file:{path}?mode=ro&immutable=1"
 
 
 
 
 
 
 
 
226
 
227
  @staticmethod
228
  def _normalize(cs: str) -> str:
229
+ """Если ΠΏΠ΅Ρ€Π΅Π΄Π°Π½ просто ΠΏΡƒΡ‚ΡŒ ΠΊ Ρ„Π°ΠΉΠ»Ρƒ β€” ΠΏΡ€Π΅Π²Ρ€Π°Ρ‰Π°Π΅ΠΌ Π² sqlite:// URI.
230
+
231
+ Если строка ΡƒΠΆΠ΅ выглядит ΠΊΠ°ΠΊ URI (sqlite/postgres/mysql) β€”
232
+ Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅ΠΌ ΠΊΠ°ΠΊ Π΅ΡΡ‚ΡŒ. Π‘Π΅Π· этой ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠΈ сцСнарий Β«ΠΏΠ΅Ρ€Π΅Π΄Π°Π»ΠΈ
233
+ ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹ΠΉ sqlite:///pathΒ» ΠΏΡ€ΠΈΠ²ΠΎΠ΄ΠΈΠ» ΠΊ Π΄Π²ΠΎΠΉΠ½ΠΎΠΉ Π½ΠΎΡ€ΠΌΠ°Π»ΠΈΠ·Π°Ρ†ΠΈΠΈ
234
+ ΠΈ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡŽ ΠΊ Π½Π΅ΡΡƒΡ‰Π΅ΡΡ‚Π²ΡƒΡŽΡ‰Π΅ΠΌΡƒ ΠΏΡƒΡ‚ΠΈ.
235
+ """
236
  cs = cs.strip()
237
+ if cs.startswith(("sqlite:", "postgres", "mysql")):
238
+ return cs
239
  if cs.endswith(".sqlite") or cs.endswith(".db"):
240
  return f"sqlite:///{cs}"
241
  return cs
 
248
  return "postgresql"
249
  if cs.startswith("mysql"):
250
  return "mysql"
251
+ raise ValueError(f"НС ΡƒΠ΄Π°Π»ΠΎΡΡŒ ΠΎΠΏΡ€Π΅Π΄Π΅Π»ΠΈΡ‚ΡŒ Ρ‚ΠΈΠΏ Π‘Π”: {cs}")
src/db/executor.py CHANGED
@@ -1,23 +1,25 @@
1
- """SqlExecutor -- vypolnyaet SQL-zapros na podklyuchennoy BD i vozvraschaet rezultat.
2
 
3
- Primer:
4
- executor = SqlExecutor("sqlite:///data/demo/sales.sqlite")
5
- result = executor.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
6
- print(result.columns)
7
- print(result.rows)
8
  """
9
 
10
  from __future__ import annotations
11
 
 
12
  import sqlite3
13
- from dataclasses import dataclass, field
14
  from pathlib import Path
15
  from urllib.parse import urlparse
16
 
 
 
17
 
18
  @dataclass
19
  class QueryResult:
20
- """Rezultat vypolneniya SQL-zaprosa."""
21
  columns: list[str]
22
  rows: list[list]
23
  row_count: int
@@ -39,9 +41,9 @@ class QueryResult:
39
 
40
  def to_markdown_table(self) -> str:
41
  if self.error:
42
- return f"Oshibka: {self.error}"
43
  if not self.rows:
44
- return "(pustoy rezultat)"
45
  header = " | ".join(self.columns)
46
  sep = " | ".join(["---"] * len(self.columns))
47
  rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
@@ -49,7 +51,7 @@ class QueryResult:
49
 
50
 
51
  class SqlExecutor:
52
- """Vypolnyaet SQL na podklyuchennoy BD."""
53
 
54
  MAX_ROWS = 500
55
 
@@ -67,13 +69,14 @@ class SqlExecutor:
67
  return self._run_mysql(sql)
68
  else:
69
  return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
70
- error=f"Neizvestnyy tip BD: {self._db_type}")
71
- except Exception as e:
 
72
  return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
73
 
74
  def _run_sqlite(self, sql: str) -> QueryResult:
75
- path = self._safe_sqlite_path(self._sqlite_path())
76
- conn = sqlite3.connect(str(path))
77
  conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
78
  try:
79
  cur = conn.cursor()
@@ -88,10 +91,12 @@ class SqlExecutor:
88
  try:
89
  import psycopg2 # type: ignore
90
  except ImportError as e:
91
- raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
92
 
93
  conn = psycopg2.connect(self.connection_string)
94
  try:
 
 
95
  cur = conn.cursor()
96
  cur.execute(sql)
97
  cols = [d[0] for d in (cur.description or [])]
@@ -104,7 +109,7 @@ class SqlExecutor:
104
  try:
105
  import pymysql # type: ignore
106
  except ImportError as e:
107
- raise ImportError("Ustanovi pymysql: pip install pymysql") from e
108
 
109
  parsed = urlparse(self.connection_string)
110
  conn = pymysql.connect(
@@ -116,6 +121,9 @@ class SqlExecutor:
116
  )
117
  try:
118
  cur = conn.cursor()
 
 
 
119
  cur.execute(sql)
120
  cols = [d[0] for d in (cur.description or [])]
121
  rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
@@ -130,16 +138,9 @@ class SqlExecutor:
130
  return Path(cs)
131
 
132
  @staticmethod
133
- def _safe_sqlite_path(path: Path) -> Path:
134
- import shutil
135
- import tempfile
136
- journal = Path(str(path) + "-journal")
137
- wal = Path(str(path) + "-wal")
138
- if journal.exists() or wal.exists():
139
- tmp = Path(tempfile.mktemp(suffix=".sqlite"))
140
- shutil.copy2(path, tmp)
141
- return tmp
142
- return path
143
 
144
  @staticmethod
145
  def _detect_type(cs: str) -> str:
@@ -149,4 +150,4 @@ class SqlExecutor:
149
  return "postgresql"
150
  if cs.startswith("mysql"):
151
  return "mysql"
152
- raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
 
1
+ """SqlExecutor β€” выполняСт SQL-запрос Π½Π° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Ρ‘Π½Π½ΠΎΠΉ Π‘Π”.
2
 
3
+ Для SQLite соСдинСниС открываСтся Ρ‡Π΅Ρ€Π΅Π· URI с ``mode=ro&immutable=1`` β€”
4
+ это обСспСчиваСт read-only Π±Π΅Π· копирования Ρ„Π°ΠΉΠ»Π° ΠΈ Ρ€Π΅ΠΆΠ΅Ρ‚ Π»ΡŽΠ±Ρ‹Π΅ ΠΏΠΎΠΏΡ‹Ρ‚ΠΊΠΈ
5
+ Π²Ρ‹ΠΏΠΎΠ»Π½ΠΈΡ‚ΡŒ DDL/DML Π½Π° ΡƒΡ€ΠΎΠ²Π½Π΅ Π΄Ρ€Π°ΠΉΠ²Π΅Ρ€Π°. Для PostgreSQL/MySQL ΠΎΡ‚Π΄Π΅Π»ΡŒΠ½Ρ‹ΠΉ
6
+ guardrail остаётся Π½Π° сторонС API (см. is_select_only Π² postprocess.py).
 
7
  """
8
 
9
  from __future__ import annotations
10
 
11
+ import logging
12
  import sqlite3
13
+ from dataclasses import dataclass
14
  from pathlib import Path
15
  from urllib.parse import urlparse
16
 
17
+ logger = logging.getLogger(__name__)
18
+
19
 
20
  @dataclass
21
  class QueryResult:
22
+ """Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ выполнСния SQL-запроса."""
23
  columns: list[str]
24
  rows: list[list]
25
  row_count: int
 
41
 
42
  def to_markdown_table(self) -> str:
43
  if self.error:
44
+ return f"Ошибка: {self.error}"
45
  if not self.rows:
46
+ return "(пустой Ρ€Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚)"
47
  header = " | ".join(self.columns)
48
  sep = " | ".join(["---"] * len(self.columns))
49
  rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
 
51
 
52
 
53
  class SqlExecutor:
54
+ """ВыполняСт SQL Π½Π° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Ρ‘Π½Π½ΠΎΠΉ Π‘Π”."""
55
 
56
  MAX_ROWS = 500
57
 
 
69
  return self._run_mysql(sql)
70
  else:
71
  return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
72
+ error=f"НСизвСстный Ρ‚ΠΈΠΏ Π‘Π”: {self._db_type}")
73
+ except Exception as e: # noqa: BLE001
74
+ logger.warning("Ошибка выполнСния SQL: %s", e)
75
  return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
76
 
77
  def _run_sqlite(self, sql: str) -> QueryResult:
78
+ path = self._sqlite_path()
79
+ conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
80
  conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
81
  try:
82
  cur = conn.cursor()
 
91
  try:
92
  import psycopg2 # type: ignore
93
  except ImportError as e:
94
+ raise ImportError("Установи psycopg2: pip install psycopg2-binary") from e
95
 
96
  conn = psycopg2.connect(self.connection_string)
97
  try:
98
+ # Вранзакция Π² Ρ€Π΅ΠΆΠΈΠΌΠ΅ READ ONLY β€” guardrail Π΄Ρ€Π°ΠΉΠ²Π΅Ρ€Π½ΠΎΠ³ΠΎ уровня.
99
+ conn.set_session(readonly=True, autocommit=False)
100
  cur = conn.cursor()
101
  cur.execute(sql)
102
  cols = [d[0] for d in (cur.description or [])]
 
109
  try:
110
  import pymysql # type: ignore
111
  except ImportError as e:
112
+ raise ImportError("Установи pymysql: pip install pymysql") from e
113
 
114
  parsed = urlparse(self.connection_string)
115
  conn = pymysql.connect(
 
121
  )
122
  try:
123
  cur = conn.cursor()
124
+ # MySQL Π½Π΅ ΠΈΠΌΠ΅Π΅Ρ‚ «глобального» read-only Ρ„Π»Π°Π³Π° Π² Π΄Ρ€Π°ΠΉΠ²Π΅Ρ€Π΅,
125
+ # Π½ΠΎ ΠΌΡ‹ ΠΌΠΎΠΆΠ΅ΠΌ ΡΡ‚Π°Ρ€Ρ‚ΠΎΠ²Π°Ρ‚ΡŒ read-only-Ρ‚Ρ€Π°Π½Π·Π°ΠΊΡ†ΠΈΡŽ.
126
+ cur.execute("START TRANSACTION READ ONLY")
127
  cur.execute(sql)
128
  cols = [d[0] for d in (cur.description or [])]
129
  rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
 
138
  return Path(cs)
139
 
140
  @staticmethod
141
+ def _sqlite_uri(path: Path) -> str:
142
+ """Read-only URI для SQLite с ΠΈοΏ½οΏ½Π½ΠΎΡ€ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ΠΌ journal/WAL."""
143
+ return f"file:{path}?mode=ro&immutable=1"
 
 
 
 
 
 
 
144
 
145
  @staticmethod
146
  def _detect_type(cs: str) -> str:
 
150
  return "postgresql"
151
  if cs.startswith("mysql"):
152
  return "mysql"
153
+ raise ValueError(f"НС ΡƒΠ΄Π°Π»ΠΎΡΡŒ ΠΎΠΏΡ€Π΅Π΄Π΅Π»ΠΈΡ‚ΡŒ Ρ‚ΠΈΠΏ Π‘Π”: {cs}")
src/models/inference.py CHANGED
@@ -1,21 +1,25 @@
1
  """Π—Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ + LoRA-Π°Π΄Π°ΠΏΡ‚Π΅Ρ€Π° ΠΈ инфСрСнс.
2
 
3
- На дСсктопС/Π½ΠΎΡƒΡ‚Π±ΡƒΠΊΠ΅ Π±Π΅Π· GPU Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ Π½Π° CPU. МСдлСнно, Π½ΠΎ достаточно для Ρ€Π°Π·Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ ΠΈ Π΄Π΅ΠΌΠΎ.
4
- На Kaggle/Colab β€” Π½Π° GPU, быстрСС.
5
  """
6
 
7
  from __future__ import annotations
8
 
 
9
  from dataclasses import dataclass
10
  from pathlib import Path
11
 
12
  import torch
13
  from transformers import AutoModelForCausalLM, AutoTokenizer
14
 
 
15
  from src.config import settings
16
  from src.data.prompt import build_chat_messages
17
  from src.models.postprocess import postprocess
18
 
 
 
19
 
20
  @dataclass
21
  class GenerationResult:
@@ -39,11 +43,18 @@ class InferenceEngine:
39
  self.model = None
40
  self._loaded = False
41
 
 
 
 
 
 
42
  def load(self) -> None:
43
  """Π›Π΅Π½ΠΈΠ²ΠΎ Π³Ρ€ΡƒΠ·ΠΈΠΌ модСль. На CPU Π±Π΅Π· ΠΊΠ²Π°Π½Ρ‚ΠΈΠ·Π°Ρ†ΠΈΠΈ."""
44
  if self._loaded:
45
  return
46
 
 
 
47
  self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
48
  # bfloat16 Π²Π΄Π²ΠΎΠ΅ мСньшС float32 (~6 Π“Π‘ vs ~12 Π“Π‘) ΠΈ поддСрТиваСтся Π½Π° CPU
49
  self.model = AutoModelForCausalLM.from_pretrained(
@@ -57,38 +68,53 @@ class InferenceEngine:
57
  adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
58
  try:
59
  from peft import PeftModel
 
60
  self.model = PeftModel.from_pretrained(self.model, adapter_id)
61
  except ImportError:
62
- pass # peft Π½Π΅ установлСн β€” Ρ€Π°Π±ΠΎΡ‚Π°Π΅ΠΌ Π½Π° Π±Π°Π·ΠΎΠ²ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
 
 
 
63
 
64
  self.model.eval()
65
  self._loaded = True
 
66
 
67
  def generate(
68
  self,
69
  schema: str,
70
  question: str,
 
71
  max_new_tokens: int | None = None,
72
  ) -> GenerationResult:
73
- """ΠŸΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ schema (тСкст DDL) ΠΈ вопрос, Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ SQL."""
 
 
 
 
 
74
  if not self._loaded:
75
  self.load()
76
 
77
- messages = build_chat_messages(schema, question)
78
  prompt = self.tokenizer.apply_chat_template(
79
  messages, tokenize=False, add_generation_prompt=True
80
  )
81
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
82
 
 
 
 
 
 
 
 
 
 
 
83
  with torch.no_grad():
84
- output_ids = self.model.generate(
85
- **inputs,
86
- max_new_tokens=max_new_tokens or settings.max_new_tokens,
87
- do_sample=settings.do_sample,
88
- temperature=settings.temperature if settings.do_sample else 1.0,
89
- pad_token_id=self.tokenizer.eos_token_id,
90
- )
91
-
92
- new_tokens = output_ids[0][inputs["input_ids"].shape[1] :]
93
  raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
94
  return GenerationResult(sql=postprocess(raw), raw_output=raw)
 
1
  """Π—Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΌΠΎΠ΄Π΅Π»ΠΈ + LoRA-Π°Π΄Π°ΠΏΡ‚Π΅Ρ€Π° ΠΈ инфСрСнс.
2
 
3
+ На дСсктопС/Π½ΠΎΡƒΡ‚Π±ΡƒΠΊΠ΅ Π±Π΅Π· GPU Ρ€Π°Π±ΠΎΡ‚Π°Π΅Ρ‚ Π½Π° CPU. МСдлСнно, Π½ΠΎ достаточно для
4
+ Ρ€Π°Π·Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ ΠΈ Π΄Π΅ΠΌΠΎ. На Kaggle/Colab β€” Π½Π° GPU, быстрСС.
5
  """
6
 
7
  from __future__ import annotations
8
 
9
+ import logging
10
  from dataclasses import dataclass
11
  from pathlib import Path
12
 
13
  import torch
14
  from transformers import AutoModelForCausalLM, AutoTokenizer
15
 
16
+ from src.business.vocabulary import BusinessVocabulary
17
  from src.config import settings
18
  from src.data.prompt import build_chat_messages
19
  from src.models.postprocess import postprocess
20
 
21
+ logger = logging.getLogger(__name__)
22
+
23
 
24
  @dataclass
25
  class GenerationResult:
 
43
  self.model = None
44
  self._loaded = False
45
 
46
+ @property
47
+ def loaded(self) -> bool:
48
+ """ΠŸΡƒΠ±Π»ΠΈΡ‡Π½ΠΎΠ΅ свойство β€” статус Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠΈ ΠΌΠΎΠ΄Π΅Π»ΠΈ."""
49
+ return self._loaded
50
+
51
  def load(self) -> None:
52
  """Π›Π΅Π½ΠΈΠ²ΠΎ Π³Ρ€ΡƒΠ·ΠΈΠΌ модСль. На CPU Π±Π΅Π· ΠΊΠ²Π°Π½Ρ‚ΠΈΠ·Π°Ρ†ΠΈΠΈ."""
53
  if self._loaded:
54
  return
55
 
56
+ logger.info("Π—Π°Π³Ρ€ΡƒΠ·ΠΊΠ° Π±Π°Π·ΠΎΠ²ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ %s Π½Π° устройство %s",
57
+ self.base_model_name, self.device)
58
  self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
59
  # bfloat16 Π²Π΄Π²ΠΎΠ΅ мСньшС float32 (~6 Π“Π‘ vs ~12 Π“Π‘) ΠΈ поддСрТиваСтся Π½Π° CPU
60
  self.model = AutoModelForCausalLM.from_pretrained(
 
68
  adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
69
  try:
70
  from peft import PeftModel
71
+ logger.info("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ LoRA-Π°Π΄Π°ΠΏΡ‚Π΅Ρ€Π° %s", adapter_id)
72
  self.model = PeftModel.from_pretrained(self.model, adapter_id)
73
  except ImportError:
74
+ logger.warning("peft Π½Π΅ установлСн, ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ базовая модСль Π±Π΅Π· LoRA")
75
+ except Exception as e: # noqa: BLE001 β€” Π»ΠΎΠ³ достаточСн, Π±Π΅Π· падСния
76
+ logger.warning("НС ΡƒΠ΄Π°Π»ΠΎΡΡŒ ΠΏΠΎΠ΄Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ LoRA-Π°Π΄Π°ΠΏΡ‚Π΅Ρ€ %s: %s",
77
+ adapter_id, e)
78
 
79
  self.model.eval()
80
  self._loaded = True
81
+ logger.info("InferenceEngine Π³ΠΎΡ‚ΠΎΠ² ΠΊ Ρ€Π°Π±ΠΎΡ‚Π΅")
82
 
83
  def generate(
84
  self,
85
  schema: str,
86
  question: str,
87
+ vocabulary: BusinessVocabulary | None = None,
88
  max_new_tokens: int | None = None,
89
  ) -> GenerationResult:
90
+ """ΠŸΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ schema (тСкст DDL) ΠΈ вопрос, Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ SQL.
91
+
92
+ Если ΠΏΠ΅Ρ€Π΅Π΄Π°Π½ нСпустой ``vocabulary``, бизнСс-Ρ‚Π΅Ρ€ΠΌΠΈΠ½Ρ‹ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ
93
+ ΠΏΠΎΠ΄ΠΌΠ΅ΡˆΠΈΠ²Π°ΡŽΡ‚ΡΡ Π² систСмноС сообщСниС Ρ‡Π΅Ρ€Π΅Π· PromptBuilder.
94
+ Π­Ρ‚ΠΎ соотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Ρƒ 3.6 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки.
95
+ """
96
  if not self._loaded:
97
  self.load()
98
 
99
+ messages = build_chat_messages(schema, question, vocabulary=vocabulary)
100
  prompt = self.tokenizer.apply_chat_template(
101
  messages, tokenize=False, add_generation_prompt=True
102
  )
103
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
104
 
105
+ # ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ сэмплинга. ΠŸΡ€ΠΈ do_sample=False temperature игнорируСтся,
106
+ # поэтому Π½Π΅ ΠΏΠ΅Ρ€Π΅Π΄Π°Ρ‘ΠΌ Π΅Ρ‘ β€” ΠΈΠ½Π°Ρ‡Π΅ transformers Π²Ρ‹Π²ΠΎΠ΄ΠΈΡ‚ warning.
107
+ gen_kwargs = {
108
+ "max_new_tokens": max_new_tokens or settings.max_new_tokens,
109
+ "do_sample": settings.do_sample,
110
+ "pad_token_id": self.tokenizer.eos_token_id,
111
+ }
112
+ if settings.do_sample:
113
+ gen_kwargs["temperature"] = settings.temperature
114
+
115
  with torch.no_grad():
116
+ output_ids = self.model.generate(**inputs, **gen_kwargs)
117
+
118
+ new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
 
 
 
 
 
 
119
  raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
120
  return GenerationResult(sql=postprocess(raw), raw_output=raw)
src/models/postprocess.py CHANGED
@@ -1,24 +1,69 @@
1
- """ΠŸΠΎΡΡ‚ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° SQL: чистка Π²Ρ‹Π²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ базовая валидация Ρ‡Π΅Ρ€Π΅Π· sqlglot."""
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import re
6
 
7
  import sqlglot
 
8
  from sqlglot.errors import ParseError
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def strip_model_artifacts(text: str) -> str:
12
- """Π£Π±ΠΈΡ€Π°Π΅Ρ‚ markdown-Π±Π»ΠΎΠΊΠΈ, прСфиксы, лишний тСкст послС SQL."""
13
- # ```sql ... ```
14
- m = re.search(r"```(?:sql)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
15
- if m:
16
- text = m.group(1)
 
 
 
 
 
 
17
 
18
- # Π£Π±ΠΈΡ€Π°Π΅ΠΌ "SQL:", "ΠžΡ‚Π²Π΅Ρ‚:" ΠΈ Ρ‚.ΠΏ. Π² Π½Π°Ρ‡Π°Π»Π΅
19
- text = re.sub(r"^\s*(?:SQL|ΠžΡ‚Π²Π΅Ρ‚|Answer)\s*:\s*", "", text, flags=re.IGNORECASE)
 
 
 
20
 
21
- # Если Π΅ΡΡ‚ΡŒ нСсколько SQL β€” Π±Π΅Ρ€Ρ‘ΠΌ ΠΏΠ΅Ρ€Π²Ρ‹ΠΉ Π΄ΠΎ Ρ‚ΠΎΡ‡ΠΊΠΈ с запятой
22
  text = text.strip()
23
  if ";" in text:
24
  head, _, _ = text.partition(";")
@@ -28,23 +73,75 @@ def strip_model_artifacts(text: str) -> str:
28
 
29
 
30
  def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
31
- """ΠŸΠ°Ρ€ΡΠΈΡ‚ΡΡ Π»ΠΈ SQL Ρ‡Π΅Ρ€Π΅Π· sqlglot."""
 
 
 
 
 
 
 
 
32
  try:
33
- sqlglot.parse_one(sql, dialect=dialect)
34
- return True
35
- except ParseError:
 
 
36
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
40
- """Нормализация для Exact Match: Π΅Π΄ΠΈΠ½Ρ‹ΠΉ рСгистр ΠΊΠ»ΡŽΡ‡Π΅Π²Ρ‹Ρ… слов, ΠΏΡ€ΠΎΠ±Π΅Π»Ρ‹."""
 
 
 
 
 
41
  try:
42
- return sqlglot.parse_one(sql, dialect=dialect).sql(dialect=dialect, pretty=False).lower()
43
- except ParseError:
44
- # Если Π½Π΅ парсится β€” просто Π½ΠΈΠΆΠ½ΠΈΠΉ рСгистр ΠΈ схлопываниС ΠΏΡ€ΠΎΠ±Π΅Π»ΠΎΠ²
45
- return re.sub(r"\s+", " ", sql.lower()).strip().rstrip(";")
46
 
47
 
48
  def postprocess(raw_output: str) -> str:
49
- """ΠŸΠΎΠ»Π½Ρ‹ΠΉ pipeline постобработки."""
50
- return strip_model_artifacts(raw_output)
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ΠŸΠΎΡΡ‚ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° SQL: чистка Π²Ρ‹Π²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ, валидация ΠΈ нормализация.
2
+
3
+ БоотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Ρƒ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки. Pipeline:
4
+ raw_output ──► strip_model_artifacts ──► is_valid_sql ──► sql | ""
5
+
6
+ Π”ΠΎΠΏΠΎΠ»Π½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ ΠΌΠΎΠ΄ΡƒΠ»ΡŒ прСдоставляСт:
7
+ is_select_only(sql) β€” AST-ΡƒΡ€ΠΎΠ²Π½Π΅Π²Ρ‹ΠΉ Π³Π²Π°Ρ€Π΄Π΅ΠΉΠ» ΠΏΡ€ΠΎΡ‚ΠΈΠ² DDL/DML
8
+ ΠΏΠ΅Ρ€Π΅Π΄ Π²Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ сгСнСрированного запроса;
9
+ normalize_sql(sql) β€” каноничСская Ρ„ΠΎΡ€ΠΌΠ° для расчёта Exact Match
10
+ (совмСстима с evaluate_pauq.py).
11
+ """
12
 
13
  from __future__ import annotations
14
 
15
+ import logging
16
  import re
17
 
18
  import sqlglot
19
+ from sqlglot import exp
20
  from sqlglot.errors import ParseError
21
 
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # ΠšΠ»ΡŽΡ‡Π΅Π²Ρ‹Π΅ слова, с ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Ρ… ΠΌΠΎΠΆΠ΅Ρ‚ Π½Π°Ρ‡ΠΈΠ½Π°Ρ‚ΡŒΡΡ ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹ΠΉ SQL-запрос.
25
+ _SQL_START_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")
26
+ _SQL_START_REGEX = re.compile(
27
+ r"\b(" + "|".join(_SQL_START_KEYWORDS) + r")\b",
28
+ flags=re.IGNORECASE,
29
+ )
30
+ _FENCE_REGEX = re.compile(r"```(?:sql)?\s*(.*?)```", flags=re.DOTALL | re.IGNORECASE)
31
+ _PREFIX_REGEX = re.compile(r"^\s*(?:SQL|ΠžΡ‚Π²Π΅Ρ‚|Answer)\s*:\s*", flags=re.IGNORECASE)
32
+
33
+ # Π’ΠΈΠΏΡ‹ AST-ΡƒΠ·Π»ΠΎΠ², ΠΊΠΎΡ‚ΠΎΡ€Ρ‹Π΅ ΠΌΡ‹ считаСм «осмыслСнными» SQL-запросами.
34
+ # sqlglot β€” Π»ΠΎΡΠ»ΡŒΠ½Ρ‹ΠΉ парсСр: 'garbage text' ΠΎΠ½ распарсит ΠΊΠ°ΠΊ Column/Table.
35
+ # Π‘Π΅Π· ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠΈ isinstance Ρ‚Π°ΠΊΠΈΠ΅ случаи Π±ΡƒΠ΄ΡƒΡ‚ ΠΏΡ€ΠΎΡ…ΠΎΠ΄ΠΈΡ‚ΡŒ is_valid_sql.
36
+ _VALID_ROOT_TYPES: tuple[type[exp.Expression], ...] = (
37
+ exp.Select,
38
+ exp.With,
39
+ exp.Insert,
40
+ exp.Update,
41
+ exp.Delete,
42
+ exp.Union,
43
+ exp.Intersect,
44
+ exp.Except,
45
+ )
46
+
47
 
48
  def strip_model_artifacts(text: str) -> str:
49
+ """ΠžΡ‡ΠΈΡ‰Π°Π΅Ρ‚ Π²Ρ‹Π²ΠΎΠ΄ ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΎΡ‚ markdown ΠΈ пояснСний Π΄ΠΎ Π½Π°Ρ‡Π°Π»Π° SQL-запроса.
50
+
51
+ Π¨Π°Π³ΠΈ:
52
+ 1. Если ΠΎΡ‚Π²Π΅Ρ‚ ΠΎΠ±Ρ‘Ρ€Π½ΡƒΡ‚ Π² ```sql ... ``` β€” извлСкаСтся содСрТимоС.
53
+ 2. Π£Π΄Π°Π»ΡΡŽΡ‚ΡΡ прСфиксы Π²ΠΈΠ΄Π° Β«SQL:Β», Β«ΠžΡ‚Π²Π΅Ρ‚:Β», Β«Answer:Β».
54
+ 3. Π˜Ρ‰Π΅Ρ‚ΡΡ ΠΏΠ΅Ρ€Π²ΠΎΠ΅ Π²Ρ…ΠΎΠΆΠ΄Π΅Π½ΠΈΠ΅ SQL-ΠΊΠ»ΡŽΡ‡Π΅Π²ΠΎΠ³ΠΎ слова, всё Π΄ΠΎ Π½Π΅Π³ΠΎ отбрасываСтся.
55
+ 4. БСрётся ΠΏΠ΅Ρ€Π²Ρ‹ΠΉ statement Π΄ΠΎ ΠΏΠ΅Ρ€Π²ΠΎΠΉ Ρ‚ΠΎΡ‡ΠΊΠΈ с запятой Π²ΠΊΠ»ΡŽΡ‡ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ.
56
+ """
57
+ fence = _FENCE_REGEX.search(text)
58
+ if fence:
59
+ text = fence.group(1)
60
 
61
+ text = _PREFIX_REGEX.sub("", text)
62
+
63
+ keyword_match = _SQL_START_REGEX.search(text)
64
+ if keyword_match:
65
+ text = text[keyword_match.start():]
66
 
 
67
  text = text.strip()
68
  if ";" in text:
69
  head, _, _ = text.partition(";")
 
73
 
74
 
75
  def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
76
+ """ΠŸΡ€ΠΎΠ²Π΅Ρ€ΡΠ΅Ρ‚, Ρ‡Ρ‚ΠΎ строка β€” это Π²Π°Π»ΠΈΠ΄Π½Ρ‹ΠΉ SQL-запрос.
77
+
78
+ ΠŸΠ°Ρ€ΡΠΈΡ‚ΡΡ Ρ‡Π΅Ρ€Π΅Π· sqlglot ΠΈ Π΄ΠΎΠΏΠΎΠ»Π½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ провСряСтся, Ρ‡Ρ‚ΠΎ ΠΊΠΎΡ€Π΅Π½ΡŒ AST β€”
79
+ это ΠΎΠ΄ΠΈΠ½ ΠΈΠ· «осмыслСнных» Ρ‚ΠΈΠΏΠΎΠ² запроса (SELECT/WITH/INSERT/UPDATE/
80
+ DELETE/UNION). Π‘Π΅Π· ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠΈ Ρ‚ΠΈΠΏΠ° sqlglot ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ Π·Π° SQL Π΄Π°ΠΆΠ΅
81
+ случайныС ΠΈΠ΄Π΅Π½Ρ‚ΠΈΡ„ΠΈΠΊΠ°Ρ‚ΠΎΡ€Ρ‹, ΠΏΠΎΡ‚ΠΎΠΌΡƒ Ρ‡Ρ‚ΠΎ ΠΎΠ½ Π»ΠΎΡΠ»ΡŒΠ½Ρ‹ΠΉ парсСр.
82
+ """
83
+ if not sql or not sql.strip():
84
+ return False
85
  try:
86
+ parsed = sqlglot.parse_one(sql, dialect=dialect)
87
+ except (ParseError, ValueError, TypeError) as e:
88
+ logger.debug("sqlglot Π½Π΅ смог Ρ€Π°Π·ΠΎΠ±Ρ€Π°Ρ‚ΡŒ SQL: %s", e)
89
+ return False
90
+ if parsed is None:
91
  return False
92
+ return isinstance(parsed, _VALID_ROOT_TYPES)
93
+
94
+
95
+ def is_select_only(sql: str, dialect: str = "sqlite") -> bool:
96
+ """Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ True, Ссли SQL β€” это SELECT (Π² Ρ‚. Ρ‡. Π²Π½ΡƒΡ‚Ρ€ΠΈ WITH-CTE).
97
+
98
+ Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ΡΡ ΠΊΠ°ΠΊ guardrail ΠΏΠ΅Ρ€Π΅Π΄ Π²Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ΠΌ сгСнСрированного запроса
99
+ Π½Π° Ρ€Π΅Π°Π»ΡŒΠ½ΠΎΠΉ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ…: модСль Π½Π΅ Π΄ΠΎΠ»ΠΆΠ½Π° ΠΏΠΎΠ»ΡƒΡ‡ΠΈΡ‚ΡŒ Π²ΠΎΠ·ΠΌΠΎΠΆΠ½ΠΎΡΡ‚ΡŒ Π²Ρ‹Π·Π²Π°Ρ‚ΡŒ
100
+ DROP/UPDATE/DELETE/INSERT, Π΄Π°ΠΆΠ΅ Ссли Ρ‚Π°ΠΊΠΈΠ΅ конструкции синтаксичСски
101
+ ΠΊΠΎΡ€Ρ€Π΅ΠΊΡ‚Π½Ρ‹.
102
+ """
103
+ if not sql or not sql.strip():
104
+ return False
105
+ try:
106
+ parsed = sqlglot.parse_one(sql, dialect=dialect)
107
+ except (ParseError, ValueError, TypeError):
108
+ return False
109
+ if parsed is None:
110
+ return False
111
+ if isinstance(parsed, exp.Select):
112
+ return True
113
+ if isinstance(parsed, exp.With):
114
+ return isinstance(parsed.this, exp.Select)
115
+ if isinstance(parsed, exp.Subquery):
116
+ return isinstance(parsed.this, exp.Select)
117
+ return False
118
 
119
 
120
  def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
121
+ """ΠšΠ°Π½ΠΎΠ½ΠΈΡ‡Π΅ΡΠΊΠ°Ρ Ρ„ΠΎΡ€ΠΌΠ° для расчёта Exact Match.
122
+
123
+ Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅Ρ‚ sqlglot с Ρ„Π»Π°Π³ΠΎΠΌ ``normalize=True`` β€” это Π½ΠΎΡ€ΠΌΠ°Π»ΠΈΠ·ΡƒΠ΅Ρ‚ рСгистр
124
+ ΠΊΠ»ΡŽΡ‡Π΅Π²Ρ‹Ρ… слов ΠΈ ΠΈΠ΄Π΅Π½Ρ‚ΠΈΡ„ΠΈΠΊΠ°Ρ‚ΠΎΡ€ΠΎΠ². Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ приводится ΠΊ Π²Π΅Ρ€Ρ…Π½Π΅ΠΌΡƒ рСгистру,
125
+ Ρ‡Ρ‚ΠΎΠ±Ρ‹ EM считался ΠΈΠ΄Π΅Π½Ρ‚ΠΈΡ‡Π½ΠΎ эталонной Ρ€Π΅Π°Π»ΠΈΠ·Π°Ρ†ΠΈΠΈ Π² ``evaluate_pauq.py``.
126
+ """
127
  try:
128
+ parsed = sqlglot.parse_one(sql, dialect=dialect)
129
+ return parsed.sql(dialect=dialect, normalize=True).upper()
130
+ except (ParseError, ValueError, TypeError):
131
+ return re.sub(r"\s+", " ", sql.upper()).strip().rstrip(";")
132
 
133
 
134
  def postprocess(raw_output: str) -> str:
135
+ """ΠŸΠΎΠ»Π½Ρ‹ΠΉ pipeline постобработки Π²Ρ‹Π²ΠΎΠ΄Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ.
136
+
137
+ 1. Чистка Π°Ρ€Ρ‚Π΅Ρ„Π°ΠΊΡ‚ΠΎΠ² Ρ‡Π΅Ρ€Π΅Π· :func:`strip_model_artifacts`.
138
+ 2. Валидация Ρ‡Π΅Ρ€Π΅Π· :func:`is_valid_sql`.
139
+ 3. Π’ΠΎΠ·Π²Ρ€Π°Ρ‚ пустой строки ΠΏΡ€ΠΈ ΠΏΡ€ΠΎΠ²Π°Π»Π΅ Π²Π°Π»ΠΈΠ΄Π°Ρ†ΠΈΠΈ.
140
+
141
+ БоотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Ρƒ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки.
142
+ """
143
+ sql = strip_model_artifacts(raw_output)
144
+ if not is_valid_sql(sql):
145
+ logger.warning("postprocess отбросил Π½Π΅Π²Π°Π»ΠΈΠ΄Π½Ρ‹ΠΉ SQL: %r", sql[:120])
146
+ return ""
147
+ return sql
src/streamlit_app.py CHANGED
@@ -1,9 +1,12 @@
1
- import os
2
- import sys
3
 
4
- # ЗапускаСм основноС ΠΏΡ€ΠΈΠ»ΠΎΠΆΠ΅Π½ΠΈΠ΅ ΠΈΠ· корня ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π°
5
- root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
- os.chdir(root)
7
- sys.path.insert(0, root)
8
 
9
- exec(open(os.path.join(root, "streamlit_app.py")).read())
 
 
 
 
 
 
 
1
+ """Π£ΡΡ‚Π°Ρ€Π΅Π²ΡˆΠΈΠΉ ΠΌΠΎΠ΄ΡƒΠ»ΡŒ.
 
2
 
3
+ Π’ΠΎΡ‡ΠΊΠ° Π²Ρ…ΠΎΠ΄Π° Streamlit-прилоТСния β€” ``streamlit_app.py`` Π² ΠΊΠΎΡ€Π½Π΅ ΠΏΡ€ΠΎΠ΅ΠΊΡ‚Π°.
4
+ Запуск:
 
 
5
 
6
+ streamlit run streamlit_app.py
7
+
8
+ Π­Ρ‚ΠΎΡ‚ Ρ„Π°ΠΉΠ» оставлСн пустым Ρ€Π°Π΄ΠΈ ΠΎΠ±Ρ€Π°Ρ‚Π½ΠΎΠΉ совмСстимости import-ΠΏΡƒΡ‚Π΅ΠΉ.
9
+ Π Π°Π½Π΅Π΅ здСсь Π±Ρ‹Π»Π° ΠΎΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° Ρ‡Π΅Ρ€Π΅Π· ``exec(open(...).read())``, которая
10
+ запускала UI ΠΊΠ°ΠΊ ΠΏΠΎΠ±ΠΎΡ‡Π½Ρ‹ΠΉ эффСкт ΠΈΠΌΠΏΠΎΡ€Ρ‚Π° ΠΏΠ°ΠΊΠ΅Ρ‚Π° ``src`` β€” Ρ‡Ρ‚ΠΎ Π»ΠΎΠΌΠ°Π»ΠΎ
11
+ сбор тСстов ΠΈ статичСский Π°Π½Π°Π»ΠΈΠ·. Π₯Π°ΠΊ ΡƒΠ΄Π°Π»Ρ‘Π½.
12
+ """
streamlit_app.py CHANGED
@@ -1,16 +1,44 @@
1
- """Streamlit-интСрфСйс ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ Ru2SQL."""
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
 
5
  import sys
6
- import time
7
  from pathlib import Path
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  import streamlit as st
10
 
11
  ROOT = Path(__file__).resolve().parent
12
  sys.path.insert(0, str(ROOT))
13
 
 
 
 
 
 
 
 
14
  # ──────────────────────────────────────────────
15
  # ΠšΠΎΠ½Ρ„ΠΈΠ³ΡƒΡ€Π°Ρ†ΠΈΡ страницы
16
  # ──────────────────────────────────────────────
@@ -21,11 +49,10 @@ st.set_page_config(
21
  )
22
 
23
  # ──────────────────────────────────────────────
24
- # CSS
25
  # ──────────────────────────────────────────────
26
  st.markdown("""
27
  <style>
28
- /* ── Π“Π»ΠΎΠ±Π°Π»ΡŒΠ½Ρ‹ΠΉ ΡˆΡ€ΠΈΡ„Ρ‚ ΠΈ Ρ„ΠΎΠ½ ── */
29
  html, body, [data-testid="stAppViewContainer"] {
30
  background-color: #0d1117;
31
  font-size: 16px;
@@ -36,7 +63,6 @@ st.markdown("""
36
  }
37
  [data-testid="stHeader"] { background: transparent; }
38
 
39
- /* ── Π¨Π°ΠΏΠΊΠ° ── */
40
  .app-header {
41
  padding: 32px 0 24px 0;
42
  border-bottom: 1px solid #30363d;
@@ -57,8 +83,6 @@ st.markdown("""
57
  font-weight: 400;
58
  letter-spacing: 0.1px;
59
  }
60
-
61
- /* ── Π‘Π°ΠΉΠ΄Π±Π°Ρ€: сСкции ── */
62
  .sb-label {
63
  font-size: 10px;
64
  font-weight: 700;
@@ -73,32 +97,13 @@ st.markdown("""
73
  border-top: 1px solid #30363d;
74
  margin: 4px 0 0 0;
75
  }
76
-
77
- /* ── Бтатусы ── */
78
  .status-ok { color: #3fb950; font-size: 13px; font-weight: 600; }
79
  .status-err { color: #f85149; font-size: 13px; font-weight: 600; }
80
-
81
- /* ── DB-ΠΏΠ΅Ρ€Π΅ΠΊΠ»ΡŽΡ‡Π°Ρ‚Π΅Π»ΡŒ ── */
82
- div[data-testid="stRadio"] > div {
83
- gap: 4px;
84
- }
85
- div[data-testid="stRadio"] > div > label {
86
- font-size: 14px;
87
- padding: 4px 0;
88
- }
89
-
90
- /* ── Кнопка словаря ── */
91
- .vocab-status {
92
- font-size: 12px;
93
- color: #7d8590;
94
- margin-top: 6px;
95
- }
96
-
97
- /* ── SQL-Π±Π»ΠΎΠΊ ── */
98
  .sql-box {
99
  background: #161b22;
100
  color: #e6edf3;
101
- font-family: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', 'Courier New', monospace;
102
  font-size: 14px;
103
  line-height: 1.7;
104
  padding: 20px 24px;
@@ -108,11 +113,7 @@ st.markdown("""
108
  white-space: pre-wrap;
109
  margin: 14px 0;
110
  }
111
-
112
- /* ── Π’ΠΊΠ»Π°Π΄ΠΊΠΈ ── */
113
  [data-testid="stTabs"] button { font-size: 15px; font-weight: 500; }
114
-
115
- /* ── ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ запросов ── */
116
  .examples-label {
117
  font-size: 11px;
118
  font-weight: 700;
@@ -121,47 +122,21 @@ st.markdown("""
121
  color: #7d8590;
122
  margin: 24px 0 10px 0;
123
  }
124
-
125
- /* ── ПолС Π²Π²ΠΎΠ΄Π° вопроса ── */
126
  [data-testid="stTextArea"] textarea {
127
  font-size: 16px !important;
128
  line-height: 1.6 !important;
129
  }
130
-
131
- /* ── Кнопка Β«Π’Ρ‹ΠΏΠΎΠ»Π½ΠΈΡ‚ΡŒΒ» ── */
132
  [data-testid="stButton"] > button[kind="primary"] {
133
  font-size: 15px;
134
  padding: 10px 28px;
135
  border-radius: 8px;
136
  font-weight: 600;
137
  }
138
-
139
- /* ── ΠœΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ ── */
140
- [data-testid="stMetric"] label {
141
- font-size: 12px !important;
142
- color: #7d8590 !important;
143
- }
144
- [data-testid="stMetricValue"] {
145
- font-size: 22px !important;
146
- color: #e6edf3 !important;
147
- }
148
-
149
- /* ── ΠŸΡ€Π΅Π΄ΡƒΠΏΡ€Π΅ΠΆΠ΄Π΅Π½ΠΈΠ΅ ΠΎ нСготовности ── */
150
- [data-testid="stAlertContainer"] {
151
- border-radius: 8px;
152
- font-size: 14px;
153
- }
154
-
155
- /* ── Expander схСмы ── */
156
- [data-testid="stExpander"] summary {
157
- font-size: 15px;
158
- font-weight: 500;
159
- }
160
-
161
- /* ── Π‘ΠΊΡ€Ρ‹Ρ‚ΡŒ ΠΊΠ½ΠΎΠΏΠΊΡƒ Stop ── */
162
  button[kind="stop"] { display: none !important; }
163
-
164
- /* ── ΠœΠΎΠ΄Π°Π»ΡŒΠ½Ρ‹ΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ словаря ── */
165
  [data-testid="stDialog"] textarea {
166
  font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
167
  font-size: 13px !important;
@@ -178,27 +153,21 @@ def _default_vocab_yaml() -> str:
178
  example = ROOT / "configs" / "example_vocabulary.yaml"
179
  if example.exists():
180
  return example.read_text(encoding="utf-8")
181
- return (
182
- "company: Моя компания\n\n"
183
- "terms:\n"
184
- " Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°: SUM(orders.amount) WHERE status = 'paid'\n\n"
185
- "filters:\n"
186
- " Ρ‚ΠΎΠ»ΡŒΠΊΠΎ_ΠΎΠΏΠ»Π°Ρ‡Π΅Π½Π½Ρ‹Π΅: orders.status = 'paid'\n\n"
187
- "notes: []\n"
188
- )
189
 
190
 
191
  def _init_state():
192
  defaults = {
193
- "history": [],
194
- "model_loaded": False,
195
- "engine": None,
196
- "db_connector": None,
197
- "db_executor": None,
198
- "vocabulary": None,
199
- "db_connection_string": "",
200
- "vocab_yaml": _default_vocab_yaml(),
201
- "db_mode": None,
 
202
  }
203
  for k, v in defaults.items():
204
  if k not in st.session_state:
@@ -209,58 +178,90 @@ _init_state()
209
 
210
 
211
  # ──────────────────────────────────────────────
212
- # Π’ΡΠΏΠΎΠΌΠΎΠ³Π°Ρ‚Π΅Π»ΡŒΠ½Ρ‹Π΅ Ρ„ΡƒΠ½ΠΊΡ†ΠΈΠΈ
213
  # ──────────────────────────────────────────────
214
- @st.cache_resource(show_spinner="Π˜Π½ΠΈΡ†ΠΈΠ°Π»ΠΈΠ·Π°Ρ†ΠΈΡ модСли…")
215
- def _load_engine():
216
- from src.models.inference import InferenceEngine
217
- engine = InferenceEngine()
218
- engine.load()
219
- return engine
 
 
 
220
 
221
 
222
- def _connect_db(cs: str):
223
- from src.db.connector import DbConnector
224
- from src.db.executor import SqlExecutor
225
- return DbConnector(cs), SqlExecutor(cs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
 
228
- def _load_vocab_from_yaml(yaml_text: str):
 
 
 
 
 
 
 
 
 
 
 
 
229
  import tempfile
230
- from src.business.vocabulary import BusinessVocabulary
231
  tmp = Path(tempfile.mktemp(suffix=".yaml"))
232
  tmp.write_text(yaml_text, encoding="utf-8")
233
- vocab = BusinessVocabulary.from_yaml(tmp)
234
- tmp.unlink(missing_ok=True)
235
- return vocab
236
-
237
-
238
- def _auto_connect_demo():
239
- """ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡ΠΈΡ‚ΡŒ Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρƒ ΠΈ ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ ΠΊ Π½Π΅ΠΉ."""
240
- demo_path = ROOT / "data" / "demo" / "sales.sqlite"
241
- cs = str(demo_path)
242
  try:
243
- connector, executor = _connect_db(cs)
244
- st.session_state.db_connector = connector
245
- st.session_state.db_executor = executor
246
- st.session_state.db_connection_string = cs
247
- if st.session_state.vocabulary is None:
248
- vocab_path = ROOT / "configs" / "example_vocabulary.yaml"
249
- if vocab_path.exists():
250
- st.session_state.vocabulary = _load_vocab_from_yaml(
251
- vocab_path.read_text(encoding="utf-8")
252
- )
253
- except Exception:
254
- pass
255
 
256
 
257
  # ──────────────────────────────────────────────
258
- # ΠœΠΎΠ΄Π°Π»ΡŒΠ½Ρ‹ΠΉ Π΄ΠΈΠ°Π»ΠΎΠ³ бизнСс-словаря
259
  # ──────────────────────────────────────────────
260
  @st.dialog("БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ", width="large")
261
  def vocab_dialog():
262
  st.caption(
263
- "ΠžΠΏΠΈΡˆΠΈΡ‚Π΅ Ρ‚Π΅Ρ€ΠΌΠΈΠ½Ρ‹, ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ ΠΈ ΠΏΡ€Π°Π²ΠΈΠ»Π° вашСй ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅ YAML. "
264
  "МодСль Π±ΡƒΠ΄Π΅Ρ‚ ΡƒΡ‡ΠΈΡ‚Ρ‹Π²Π°Ρ‚ΡŒ ΠΈΡ… ΠΏΡ€ΠΈ Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΠΈ SQL."
265
  )
266
  yaml_text = st.text_area(
@@ -269,77 +270,72 @@ def vocab_dialog():
269
  height=480,
270
  label_visibility="collapsed",
271
  )
272
- col1, col2 = st.columns([1, 1])
273
- with col1:
274
- if st.button("ΠŸΡ€ΠΈΠΌΠ΅Π½ΠΈΡ‚ΡŒ", type="primary", use_container_width=True):
275
  try:
276
- vocab = _load_vocab_from_yaml(yaml_text)
277
- st.session_state.vocabulary = vocab
278
  st.session_state.vocab_yaml = yaml_text
279
  st.rerun()
280
  except Exception as e:
281
  st.error(f"Ошибка синтаксиса YAML: {e}")
282
- with col2:
283
- if st.button("ΠžΡ‚ΠΌΠ΅Π½Π°", use_container_width=True):
284
  st.rerun()
285
 
286
 
287
  # ──────────────────────────────────────────────
288
- # Боковая панСль
289
  # ──────────────────────────────────────────────
290
  with st.sidebar:
291
 
292
- # ── МодСль ──
293
- st.markdown('<p class="sb-label">МодСль</p>', unsafe_allow_html=True)
294
- if not st.session_state.model_loaded:
295
- with st.spinner("Π˜Π½ΠΈΡ†ΠΈΠ°Π»ΠΈΠ·Π°Ρ†ΠΈΡβ€¦"):
296
- try:
297
- st.session_state.engine = _load_engine()
298
- st.session_state.model_loaded = True
299
- except Exception as e:
300
- st.error(f"Ошибка: {e}")
301
-
302
- if st.session_state.model_loaded:
303
- st.markdown(
304
- '<span class="status-ok">βœ… Qwen2.5-Coder-3B + QLoRA</span>',
305
- unsafe_allow_html=True,
306
- )
307
  else:
308
- st.markdown(
309
- '<span class="status-err">МодСль Π½Π΅ Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π°</span>',
310
- unsafe_allow_html=True,
311
- )
 
 
 
 
 
 
 
312
 
313
  st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
314
 
315
  # ── Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… ──
316
  st.markdown('<p class="sb-label">Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ…</p>', unsafe_allow_html=True)
317
 
318
- _modes = ["Π”Π΅ΠΌΠΎ-Π±Π°Π·Π°", "Π—Π°Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ Ρ„Π°ΠΉΠ»", "Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ"]
319
- _prev = st.session_state.db_mode
320
  db_mode = st.radio(
321
- "Π˜ΡΡ‚ΠΎΡ‡Π½ΠΈΠΊ Π΄Π°Π½Π½Ρ‹Ρ…",
322
- _modes,
323
- index=_modes.index(_prev) if _prev in _modes else None,
324
  label_visibility="collapsed",
325
  )
326
- if db_mode != _prev:
327
- # ΠŸΡ€ΠΈ смСнС Ρ€Π΅ΠΆΠΈΠΌΠ° сбрасываСм ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅
328
- st.session_state.db_connector = None
329
- st.session_state.db_executor = None
330
  st.session_state.db_mode = db_mode
331
 
332
  cs = ""
333
-
334
  if db_mode == "Π”Π΅ΠΌΠΎ-Π±Π°Π·Π°":
335
  st.caption("ВстроСнная Π±Π°Π·Π°: ΠΈΠ½Ρ‚Π΅Ρ€Π½Π΅Ρ‚-ΠΌΠ°Π³Π°Π·ΠΈΠ½ элСктроники, 120 Π·Π°ΠΊΠ°Π·ΠΎΠ².")
336
- demo_path = ROOT / "data" / "demo" / "sales.sqlite"
337
- cs = str(demo_path)
338
-
339
  elif db_mode == "Π—Π°Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ Ρ„Π°ΠΉΠ»":
340
  uploaded = st.file_uploader(
341
- "SQLite-Ρ„Π°ΠΉΠ» Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…",
342
- type=["sqlite", "db"],
343
  label_visibility="collapsed",
344
  )
345
  if uploaded:
@@ -349,81 +345,77 @@ with st.sidebar:
349
  cs = str(tmp_db)
350
  else:
351
  st.caption("ΠŸΠ΅Ρ€Π΅Ρ‚Π°Ρ‰ΠΈΡ‚Π΅ .sqlite ΠΈΠ»ΠΈ .db Ρ„Π°ΠΉΠ» сюда")
352
-
353
- else: # Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ
354
  cs = st.text_input(
355
  "Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ",
356
  placeholder="postgresql://user:pass@host:5432/db",
357
- value=st.session_state.db_connection_string,
358
  label_visibility="collapsed",
359
  )
360
- st.caption("PostgreSQL Β· MySQL (mysql+pymysql://) Β· SQLite (sqlite:///path)")
361
-
362
- if cs and st.button("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡ΠΈΡ‚ΡŒΡΡ", use_container_width=True, type="primary"):
363
- try:
364
- connector, executor = _connect_db(cs)
365
- tables = connector.list_tables()
366
- st.session_state.db_connector = connector
367
- st.session_state.db_executor = executor
368
- st.session_state.db_connection_string = cs
 
 
 
 
 
 
 
 
 
369
  if "sales" in cs and st.session_state.vocabulary is None:
370
- vocab_path = ROOT / "configs" / "example_vocabulary.yaml"
371
- if vocab_path.exists():
372
- st.session_state.vocabulary = _load_vocab_from_yaml(
373
- vocab_path.read_text(encoding="utf-8")
374
- )
 
 
 
375
  st.success(f"ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΎ. Π’Π°Π±Π»ΠΈΡ†: {len(tables)}")
376
- except Exception as e:
377
- st.error(f"Ошибка ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ: {e}")
378
 
379
- if st.session_state.db_connector:
380
- tables = st.session_state.db_connector.list_tables()
381
  st.markdown(
382
  '<span class="status-ok">βœ… Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½Π°</span>',
383
  unsafe_allow_html=True,
384
  )
385
- with st.expander(f"Π’Π°Π±Π»ΠΈΡ†Ρ‹ ({len(tables)})"):
386
- for t in tables:
387
- st.code(t, language=None)
388
 
389
  st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
390
 
391
  # ── БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ ──
392
  st.markdown('<p class="sb-label">БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ</p>', unsafe_allow_html=True)
393
-
394
  if st.session_state.vocabulary:
395
  v = st.session_state.vocabulary
396
  label = v.company if v.company else "Π—Π°Π³Ρ€ΡƒΠΆΠ΅Π½"
397
- st.markdown(
398
- f'<span class="status-ok">βœ… {label}</span>',
399
- unsafe_allow_html=True,
400
- )
401
  if v.terms:
402
- st.markdown(
403
- f'<span class="vocab-status">Π’Π΅Ρ€ΠΌΠΈΠ½ΠΎΠ²: {len(v.terms)}</span>',
404
- unsafe_allow_html=True,
405
- )
406
  else:
407
- st.markdown(
408
- '<span class="vocab-status">Π‘Π»ΠΎΠ²Π°Ρ€ΡŒ Π½Π΅ ΠΏΡ€ΠΈΠΌΠ΅Π½Ρ‘Π½</span>',
409
- unsafe_allow_html=True,
410
- )
411
-
412
- if st.button("Π Π΅Π΄Π°ΠΊΡ‚ΠΈΡ€ΠΎΠ²Π°Ρ‚ΡŒ ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ", use_container_width=True):
413
  vocab_dialog()
414
 
415
 
416
-
417
-
418
  # ──────────────────────────────────────────────
419
- # Основная ΠΎΠ±Π»Π°ΡΡ‚ΡŒ β€” шапка
420
  # ──────────────────────────────────────────────
421
  st.markdown("""
422
  <div class="app-header">
423
  <p class="app-title">Ru2SQL β€” гСнСративная модСль прСобразования запросов<br>
424
  ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… Π½Π° русском языкС Π² запросы Π½Π° языкС SQL</p>
425
  <p class="app-subtitle">
426
- Qwen2.5-Coder-3B-Instruct &nbsp;Β·&nbsp; QLoRA fine-tuning Π½Π° датасСтС PAUQ
427
  &nbsp;Β·&nbsp; SQLite / PostgreSQL / MySQL
428
  </p>
429
  </div>
@@ -431,15 +423,21 @@ st.markdown("""
431
 
432
  tab_query, tab_schema, tab_history = st.tabs(["Запрос", "Π‘Ρ…Π΅ΠΌΠ° Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…", "Π˜ΡΡ‚ΠΎΡ€ΠΈΡ"])
433
 
434
- # ──────────── Π’ΠΊΠ»Π°Π΄ΠΊΠ°: Запрос ────────────
 
435
  with tab_query:
436
- ready = st.session_state.model_loaded and st.session_state.db_connector is not None
 
 
 
 
 
437
 
438
  if not ready:
439
  missing = []
440
- if not st.session_state.model_loaded:
441
- missing.append("модСль инициализируСтся")
442
- if st.session_state.db_connector is None:
443
  missing.append("Π±Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… Π½Π΅ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½Π°")
444
  st.warning("БистСма Π½Π΅ Π³ΠΎΡ‚ΠΎΠ²Π°: " + ", ".join(missing) + ". Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠΉΡ‚Π΅ панСль слСва.")
445
 
@@ -450,19 +448,19 @@ with tab_query:
450
  disabled=not ready,
451
  )
452
 
453
- col_btn, col_spacer = st.columns([1, 5])
454
  with col_btn:
455
  run_btn = st.button(
456
  "Π’Ρ‹ΠΏΠΎΠ»Π½ΠΈΡ‚ΡŒ",
457
  type="primary",
458
  disabled=not ready or not question.strip(),
459
- use_container_width=True,
460
  )
461
 
462
  # ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ для Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ‹
463
  if (
464
- st.session_state.db_connection_string
465
- and "sales" in st.session_state.db_connection_string
466
  ):
467
  st.markdown('<p class="examples-label">ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ запросов</p>', unsafe_allow_html=True)
468
  ex_cols = st.columns(3)
@@ -473,78 +471,76 @@ with tab_query:
473
  ]
474
  for i, ex in enumerate(examples):
475
  with ex_cols[i]:
476
- if st.button(ex, key=f"ex_{i}", use_container_width=True):
477
  question = ex
478
- run_btn = True
479
 
480
  if run_btn and question.strip():
481
- engine = st.session_state.engine
482
- connector = st.session_state.db_connector
483
- executor = st.session_state.db_executor
484
- vocab = st.session_state.vocabulary
485
-
486
- enriched = vocab.enrich_prompt(question) if vocab else question
487
- schema = connector.render_schema(include_samples=True)
488
 
489
- with st.spinner("ГСнСрация SQL-запроса…"):
490
- t0 = time.time()
491
- result = engine.generate(schema, enriched)
492
- gen_time = time.time() - t0
 
 
493
 
494
  st.markdown("**Π‘Π³Π΅Π½Π΅Ρ€ΠΈΡ€ΠΎΠ²Π°Π½Π½Ρ‹ΠΉ SQL**")
495
- st.markdown(f'<div class="sql-box">{result.sql}</div>', unsafe_allow_html=True)
496
 
497
- qr = None
498
- if result.sql.strip():
499
- with st.spinner("Π’Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ запроса…"):
500
- qr = executor.run(result.sql)
501
 
502
  c1, c2, c3 = st.columns(3)
503
  c1.metric("ВрСмя Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΠΈ", f"{gen_time:.1f} с")
504
- if qr:
505
- c2.metric("Π‘Ρ‚Ρ€ΠΎΠΊ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½ΠΎ", qr.row_count if qr.success else "β€”")
506
- c3.metric("Бтатус", "УспСшно" if qr.success else "Ошибка")
507
-
508
- if qr and qr.success:
509
- if qr.rows:
510
- import pandas as pd
511
- st.markdown("**Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚**")
512
- df = pd.DataFrame(qr.rows, columns=qr.columns)
513
- st.dataframe(df, use_container_width=True)
514
- else:
515
- st.info("Запрос Π²Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ ΡƒΡΠΏΠ΅ΡˆΠ½ΠΎ. Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ пустой.")
516
- elif qr and not qr.success:
517
- st.error(f"Ошибка выполнСния SQL: {qr.error}")
 
 
518
 
519
  st.session_state.history.append({
520
  "question": question,
521
- "sql": result.sql,
522
- "success": qr.success if qr else False,
523
- "rows": qr.row_count if qr and qr.success else 0,
524
  "time": gen_time,
525
  })
526
 
527
- # ──────────── Π’ΠΊΠ»Π°Π΄ΠΊΠ°: Π‘Ρ…Π΅ΠΌΠ° Π‘Π” ────────────
 
528
  with tab_schema:
529
- if st.session_state.db_connector is None:
530
  st.info("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡ΠΈΡ‚Π΅ΡΡŒ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… Ρ‡Π΅Ρ€Π΅Π· панСль слСва.")
531
  else:
532
- connector = st.session_state.db_connector
533
  show_samples = st.toggle("ΠŸΠΎΠΊΠ°Π·Ρ‹Π²Π°Ρ‚ΡŒ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…", value=True)
534
-
535
- for table in connector.get_schema(include_samples=show_samples):
536
- with st.expander(f"{table.name} β€” {len(table.columns)} ΠΊΠΎΠ»ΠΎΠ½ΠΎΠΊ"):
537
- st.code(table.to_ddl(), language="sql")
538
- if show_samples and table.sample_rows:
539
  import pandas as pd
540
- cols = [c.name for c in table.columns]
541
  st.caption("ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…:")
542
  st.dataframe(
543
- pd.DataFrame(table.sample_rows, columns=cols),
544
- use_container_width=True,
545
  )
546
 
547
- # ──────────── Π’ΠΊΠ»Π°Π΄ΠΊΠ°: Π˜ΡΡ‚ΠΎΡ€ΠΈΡ ────────────
 
548
  with tab_history:
549
  history = st.session_state.history
550
  if not history:
@@ -554,7 +550,7 @@ with tab_history:
554
  with col_h:
555
  st.markdown(f"**Запросов Π² сСссии: {len(history)}**")
556
  with col_clr:
557
- if st.button("ΠžΡ‡ΠΈΡΡ‚ΠΈΡ‚ΡŒ", use_container_width=True):
558
  st.session_state.history = []
559
  st.rerun()
560
 
 
1
+ """Streamlit-интСрфСйс ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ Ru2SQL.
2
+
3
+ АрхитСктурно β€” ΠΊΠ»ΠΈΠ΅Π½Ρ‚ REST API Π½Π° FastAPI. БоотвСтствуСт Ρ€Π°Π·Π΄Π΅Π»Ρƒ 3.5
4
+ ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки: всС обращСния ΠΊ ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… ΠΈΠ΄ΡƒΡ‚ Ρ‡Π΅Ρ€Π΅Π·
5
+ HTTP ΠΊ ``src.api.main:app``. Запуск Π΄Π²ΡƒΡ… процСссов:
6
+
7
+ uvicorn src.api.main:app --reload # Π½Π° 127.0.0.1:8000
8
+ streamlit run streamlit_app.py # Π½Π° 127.0.0.1:8501
9
+ """
10
 
11
  from __future__ import annotations
12
 
13
+ import logging
14
+ import os
15
  import sys
16
+ import warnings
17
  from pathlib import Path
18
 
19
+ # ──────────────────────────────────────────────
20
+ # Π“Π»ΡƒΡˆΠΈΠΌ ΡˆΡƒΠΌΠ½Ρ‹Π΅ warning'ΠΈ
21
+ # ──────────────────────────────────────────────
22
+ # Streamlit-watcher Ρ…ΠΎΠ΄ΠΈΡ‚ ΠΏΠΎ всСму ΠΏΠ°ΠΊΠ΅Ρ‚Ρƒ transformers (image-processors)
23
+ # ΠΈ спамит ModuleNotFoundError ΠΏΡ€ΠΎ torchvision. На Ρ€Π°Π±ΠΎΡ‚Ρƒ это Π½Π΅ влияСт β€”
24
+ # Qwen2.5-Coder text-only, torchvision Π½Π΅ Π½ΡƒΠΆΠ΅Π½.
25
+ warnings.filterwarnings("ignore", message=".*torchvision.*")
26
+ logging.getLogger("transformers").setLevel(logging.ERROR)
27
+ logging.getLogger("streamlit.watcher.local_sources_watcher").setLevel(logging.ERROR)
28
+
29
+ import httpx
30
  import streamlit as st
31
 
32
  ROOT = Path(__file__).resolve().parent
33
  sys.path.insert(0, str(ROOT))
34
 
35
+ # БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ парсим локально β€” ΠΎΠ½ Π½Π΅ Ρ‚Ρ€Π΅Π±ΡƒΠ΅Ρ‚ обращСния ΠΊ сСрвСру
36
+ from src.business.vocabulary import BusinessVocabulary
37
+
38
+ API_URL = os.environ.get("RU2SQL_API_URL", "http://127.0.0.1:8000")
39
+ QUERY_TIMEOUT = 1800.0 # 30 ΠΌΠΈΠ½ΡƒΡ‚ β€” фактичСски Π±Π΅Π·Π»ΠΈΠΌΠΈΡ‚
40
+ SHORT_TIMEOUT = 10.0 # для /health, /schema
41
+
42
  # ──────────────────────────────────────────────
43
  # ΠšΠΎΠ½Ρ„ΠΈΠ³ΡƒΡ€Π°Ρ†ΠΈΡ страницы
44
  # ──────────────────────────────────────────────
 
49
  )
50
 
51
  # ──────────────────────────────────────────────
52
+ # CSS β€” ΠΎΡ„ΠΎΡ€ΠΌΠ»Π΅Π½ΠΈΠ΅ Π² стилС Ρ‚Ρ‘ΠΌΠ½ΠΎΠΉ Ρ‚Π΅ΠΌΡ‹ GitHub
53
  # ──────────────────────────────────────────────
54
  st.markdown("""
55
  <style>
 
56
  html, body, [data-testid="stAppViewContainer"] {
57
  background-color: #0d1117;
58
  font-size: 16px;
 
63
  }
64
  [data-testid="stHeader"] { background: transparent; }
65
 
 
66
  .app-header {
67
  padding: 32px 0 24px 0;
68
  border-bottom: 1px solid #30363d;
 
83
  font-weight: 400;
84
  letter-spacing: 0.1px;
85
  }
 
 
86
  .sb-label {
87
  font-size: 10px;
88
  font-weight: 700;
 
97
  border-top: 1px solid #30363d;
98
  margin: 4px 0 0 0;
99
  }
 
 
100
  .status-ok { color: #3fb950; font-size: 13px; font-weight: 600; }
101
  .status-err { color: #f85149; font-size: 13px; font-weight: 600; }
102
+ .status-warn { color: #d29922; font-size: 13px; font-weight: 600; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  .sql-box {
104
  background: #161b22;
105
  color: #e6edf3;
106
+ font-family: 'JetBrains Mono', 'Fira Code', monospace;
107
  font-size: 14px;
108
  line-height: 1.7;
109
  padding: 20px 24px;
 
113
  white-space: pre-wrap;
114
  margin: 14px 0;
115
  }
 
 
116
  [data-testid="stTabs"] button { font-size: 15px; font-weight: 500; }
 
 
117
  .examples-label {
118
  font-size: 11px;
119
  font-weight: 700;
 
122
  color: #7d8590;
123
  margin: 24px 0 10px 0;
124
  }
 
 
125
  [data-testid="stTextArea"] textarea {
126
  font-size: 16px !important;
127
  line-height: 1.6 !important;
128
  }
 
 
129
  [data-testid="stButton"] > button[kind="primary"] {
130
  font-size: 15px;
131
  padding: 10px 28px;
132
  border-radius: 8px;
133
  font-weight: 600;
134
  }
135
+ [data-testid="stMetric"] label { font-size: 12px !important; color: #7d8590 !important; }
136
+ [data-testid="stMetricValue"] { font-size: 22px !important; color: #e6edf3 !important; }
137
+ [data-testid="stAlertContainer"] { border-radius: 8px; font-size: 14px; }
138
+ [data-testid="stExpander"] summary { font-size: 15px; font-weight: 500; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  button[kind="stop"] { display: none !important; }
 
 
140
  [data-testid="stDialog"] textarea {
141
  font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
142
  font-size: 13px !important;
 
153
  example = ROOT / "configs" / "example_vocabulary.yaml"
154
  if example.exists():
155
  return example.read_text(encoding="utf-8")
156
+ return "company: Моя компания\n\nterms: {}\nfilters: {}\nnotes: []\n"
 
 
 
 
 
 
 
157
 
158
 
159
  def _init_state():
160
  defaults = {
161
+ "history": [],
162
+ "api_health": None, # dict | None
163
+ "api_error": None, # str | None
164
+ "connection_string": "",
165
+ "schema_tables": None, # list[TablePayload-like dict] | None
166
+ "schema_error": None,
167
+ "vocabulary": None, # BusinessVocabulary | None
168
+ "vocab_yaml": _default_vocab_yaml(),
169
+ "db_mode": None,
170
+ "warmup_done": False,
171
  }
172
  for k, v in defaults.items():
173
  if k not in st.session_state:
 
178
 
179
 
180
  # ──────────────────────────────────────────────
181
+ # ΠžΠ±Ρ‘Ρ€Ρ‚ΠΊΠΈ Π½Π°Π΄ API
182
  # ──────────────────────────────────────────────
183
+ def _api_get_health() -> dict | None:
184
+ """GET /health. None Ссли API нСдоступСн."""
185
+ try:
186
+ r = httpx.get(f"{API_URL}/health", timeout=SHORT_TIMEOUT)
187
+ r.raise_for_status()
188
+ return r.json()
189
+ except Exception as e:
190
+ st.session_state.api_error = str(e)
191
+ return None
192
 
193
 
194
+ def _api_get_schema(cs: str) -> tuple[list[dict] | None, str | None]:
195
+ """POST /schema. Π’ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅Ρ‚ (tables, error)."""
196
+ try:
197
+ r = httpx.post(
198
+ f"{API_URL}/schema",
199
+ json={"connection_string": cs, "include_samples": True},
200
+ timeout=SHORT_TIMEOUT,
201
+ )
202
+ if r.status_code != 200:
203
+ try:
204
+ return None, r.json().get("detail", r.text)
205
+ except Exception:
206
+ return None, r.text
207
+ return r.json().get("tables", []), None
208
+ except Exception as e:
209
+ return None, str(e)
210
+
211
+
212
+ def _api_query(question: str, cs: str, vocab: BusinessVocabulary | None) -> dict:
213
+ """POST /query β€” гСнСрация SQL + ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΠΎΠ΅ исполнСниС."""
214
+ payload = {
215
+ "question": question,
216
+ "connection_string": cs,
217
+ "execute": True,
218
+ }
219
+ if vocab is not None and bool(vocab):
220
+ payload["vocabulary"] = {
221
+ "company": vocab.company,
222
+ "terms": vocab.terms,
223
+ "filters": vocab.filters,
224
+ "notes": vocab.notes,
225
+ }
226
+ r = httpx.post(f"{API_URL}/query", json=payload, timeout=QUERY_TIMEOUT)
227
+ if r.status_code != 200:
228
+ try:
229
+ detail = r.json().get("detail", r.text)
230
+ except Exception:
231
+ detail = r.text
232
+ raise RuntimeError(f"API Π²Π΅Ρ€Π½ΡƒΠ» {r.status_code}: {detail}")
233
+ return r.json()
234
 
235
 
236
+
237
+ def _api_warmup() -> tuple[bool, str | None]:
238
+ """POST /warmup β€” ΠΊΠΎΡ€ΠΎΡ‚ΠΊΠΈΠΉ ΠΏΡ€ΠΎΠ³ΠΎΠ½ для ΠΏΡ€ΠΎΠ³Ρ€Π΅Π²Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ Π½Π° CPU."""
239
+ try:
240
+ r = httpx.post(f"{API_URL}/warmup", timeout=QUERY_TIMEOUT)
241
+ if r.status_code == 200:
242
+ return True, None
243
+ return False, r.text
244
+ except Exception as e:
245
+ return False, str(e)
246
+
247
+
248
+ def _load_vocab_from_yaml(yaml_text: str) -> BusinessVocabulary:
249
  import tempfile
 
250
  tmp = Path(tempfile.mktemp(suffix=".yaml"))
251
  tmp.write_text(yaml_text, encoding="utf-8")
 
 
 
 
 
 
 
 
 
252
  try:
253
+ return BusinessVocabulary.from_yaml(tmp)
254
+ finally:
255
+ tmp.unlink(missing_ok=True)
 
 
 
 
 
 
 
 
 
256
 
257
 
258
  # ──────────────────────────────────────────────
259
+ # Π”ΠΈΠ°Π»ΠΎΠ³ рСдактирования бизнСс-словаря
260
  # ──────────────────────────────────────────────
261
  @st.dialog("БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ", width="large")
262
  def vocab_dialog():
263
  st.caption(
264
+ "ΠžΠΏΠΈΡˆΠΈΡ‚Π΅ Ρ‚Π΅Ρ€ΠΌΠΈΠ½Ρ‹ ΠΈ ΠΌΠ΅Ρ‚Ρ€ΠΈΠΊΠΈ ΠΊΠΎΠΌΠΏΠ°Π½ΠΈΠΈ Π² Ρ„ΠΎΡ€ΠΌΠ°Ρ‚Π΅ YAML. "
265
  "МодСль Π±ΡƒΠ΄Π΅Ρ‚ ΡƒΡ‡ΠΈΡ‚Ρ‹Π²Π°Ρ‚ΡŒ ΠΈΡ… ΠΏΡ€ΠΈ Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΠΈ SQL."
266
  )
267
  yaml_text = st.text_area(
 
270
  height=480,
271
  label_visibility="collapsed",
272
  )
273
+ c1, c2 = st.columns(2)
274
+ with c1:
275
+ if st.button("ΠŸΡ€ΠΈΠΌΠ΅Π½ΠΈΡ‚ΡŒ", type="primary", width='stretch'):
276
  try:
277
+ st.session_state.vocabulary = _load_vocab_from_yaml(yaml_text)
 
278
  st.session_state.vocab_yaml = yaml_text
279
  st.rerun()
280
  except Exception as e:
281
  st.error(f"Ошибка синтаксиса YAML: {e}")
282
+ with c2:
283
+ if st.button("ΠžΡ‚ΠΌΠ΅Π½Π°", width='stretch'):
284
  st.rerun()
285
 
286
 
287
  # ──────────────────────────────────────────────
288
+ # Sidebar
289
  # ──────────────────────────────────────────────
290
  with st.sidebar:
291
 
292
+ # ── API ──
293
+ st.markdown('<p class="sb-label">API</p>', unsafe_allow_html=True)
294
+ health = _api_get_health()
295
+ st.session_state.api_health = health
296
+ if health is None:
297
+ st.markdown('<span class="status-err">API нСдоступСн</span>', unsafe_allow_html=True)
298
+ st.caption(f"АдрСс: {API_URL}")
299
+ st.caption("Запусти Π² ΠΎΡ‚Π΄Π΅Π»ΡŒΠ½ΠΎΠΉ консоли: `uvicorn src.api.main:app --reload`")
300
+ if st.session_state.api_error:
301
+ st.caption(f"ΠŸΡ€ΠΈΡ‡ΠΈΠ½Π°: {st.session_state.api_error[:160]}")
 
 
 
 
 
302
  else:
303
+ if health.get("model_loaded"):
304
+ st.markdown(
305
+ f'<span class="status-ok">βœ… {health.get("base_model", "модСль")}</span>',
306
+ unsafe_allow_html=True,
307
+ )
308
+ else:
309
+ st.markdown(
310
+ '<span class="status-warn">⏳ МодСль Π΅Ρ‰Ρ‘ загруТаСтся</span>',
311
+ unsafe_allow_html=True,
312
+ )
313
+ st.caption("ΠŸΠΎΠ΄ΠΎΠΆΠ΄ΠΈΡ‚Π΅ нСсколько ΠΌΠΈΠ½ΡƒΡ‚ β€” модСль Π΅Ρ‰Ρ‘ инициализируСтся.")
314
 
315
  st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
316
 
317
  # ── Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… ──
318
  st.markdown('<p class="sb-label">Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ…</p>', unsafe_allow_html=True)
319
 
320
+ modes = ["Π”Π΅ΠΌΠΎ-Π±Π°Π·Π°", "Π—Π°Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ Ρ„Π°ΠΉΠ»", "Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ"]
321
+ prev = st.session_state.db_mode
322
  db_mode = st.radio(
323
+ "Π˜ΡΡ‚ΠΎΡ‡Π½ΠΈΠΊ Π΄Π°Π½Π½Ρ‹Ρ…", modes,
324
+ index=modes.index(prev) if prev in modes else None,
 
325
  label_visibility="collapsed",
326
  )
327
+ if db_mode != prev:
328
+ st.session_state.schema_tables = None
329
+ st.session_state.connection_string = ""
 
330
  st.session_state.db_mode = db_mode
331
 
332
  cs = ""
 
333
  if db_mode == "Π”Π΅ΠΌΠΎ-Π±Π°Π·Π°":
334
  st.caption("ВстроСнная Π±Π°Π·Π°: ΠΈΠ½Ρ‚Π΅Ρ€Π½Π΅Ρ‚-ΠΌΠ°Π³Π°Π·ΠΈΠ½ элСктроники, 120 Π·Π°ΠΊΠ°Π·ΠΎΠ².")
335
+ cs = str(ROOT / "data" / "demo" / "sales.sqlite")
 
 
336
  elif db_mode == "Π—Π°Π³Ρ€ΡƒΠ·ΠΈΡ‚ΡŒ Ρ„Π°ΠΉΠ»":
337
  uploaded = st.file_uploader(
338
+ "SQLite-Ρ„Π°ΠΉΠ» Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…", type=["sqlite", "db"],
 
339
  label_visibility="collapsed",
340
  )
341
  if uploaded:
 
345
  cs = str(tmp_db)
346
  else:
347
  st.caption("ΠŸΠ΅Ρ€Π΅Ρ‚Π°Ρ‰ΠΈΡ‚Π΅ .sqlite ΠΈΠ»ΠΈ .db Ρ„Π°ΠΉΠ» сюда")
348
+ else:
 
349
  cs = st.text_input(
350
  "Π‘Ρ‚Ρ€ΠΎΠΊΠ° ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ",
351
  placeholder="postgresql://user:pass@host:5432/db",
352
+ value=st.session_state.connection_string,
353
  label_visibility="collapsed",
354
  )
355
+ st.caption("PostgreSQL Β· MySQL Β· SQLite (sqlite:///path)")
356
+
357
+ if cs and st.button("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡ΠΈΡ‚ΡŒΡΡ", width='stretch', type="primary"):
358
+ with st.spinner("Π§Ρ‚Π΅Π½ΠΈΠ΅ схСмы…"):
359
+ tables, err = _api_get_schema(cs)
360
+ if err:
361
+ st.error(f"Ошибка ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ: {err}")
362
+ st.session_state.schema_tables = None
363
+ else:
364
+ st.session_state.schema_tables = tables
365
+ st.session_state.connection_string = cs
366
+ st.session_state.schema_error = None
367
+ if not st.session_state.get("warmup_done", False):
368
+ with st.spinner("ΠŸΡ€ΠΎΠ³Ρ€Π΅Π² ΠΌΠΎΠ΄Π΅Π»ΠΈ (запускаСтся ΠΎΠ΄ΠΈΠ½ Ρ€Π°Π· Π·Π° сСссию)…"):
369
+ ok, _err = _api_warmup()
370
+ if ok:
371
+ st.session_state.warmup_done = True
372
+ # Автозагрузка словаря для Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ‹
373
  if "sales" in cs and st.session_state.vocabulary is None:
374
+ vp = ROOT / "configs" / "example_vocabulary.yaml"
375
+ if vp.exists():
376
+ try:
377
+ st.session_state.vocabulary = _load_vocab_from_yaml(
378
+ vp.read_text(encoding="utf-8")
379
+ )
380
+ except Exception:
381
+ pass
382
  st.success(f"ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΎ. Π’Π°Π±Π»ΠΈΡ†: {len(tables)}")
 
 
383
 
384
+ if st.session_state.schema_tables is not None:
385
+ n = len(st.session_state.schema_tables)
386
  st.markdown(
387
  '<span class="status-ok">βœ… Π‘Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½Π°</span>',
388
  unsafe_allow_html=True,
389
  )
390
+ with st.expander(f"Π’Π°Π±Π»ΠΈΡ†Ρ‹ ({n})"):
391
+ for t in st.session_state.schema_tables:
392
+ st.code(t.get("name", ""), language=None)
393
 
394
  st.markdown('<hr class="sb-divider">', unsafe_allow_html=True)
395
 
396
  # ── БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ ──
397
  st.markdown('<p class="sb-label">БизнСс-ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ</p>', unsafe_allow_html=True)
 
398
  if st.session_state.vocabulary:
399
  v = st.session_state.vocabulary
400
  label = v.company if v.company else "Π—Π°Π³Ρ€ΡƒΠΆΠ΅Π½"
401
+ st.markdown(f'<span class="status-ok">βœ… {label}</span>', unsafe_allow_html=True)
 
 
 
402
  if v.terms:
403
+ st.caption(f"Π’Π΅Ρ€ΠΌΠΈΠ½ΠΎΠ²: {len(v.terms)}")
 
 
 
404
  else:
405
+ st.caption("Π‘Π»ΠΎΠ²Π°Ρ€ΡŒ Π½Π΅ ΠΏΡ€ΠΈΠΌΠ΅Π½Ρ‘Π½")
406
+ if st.button("Π Π΅Π΄Π°ΠΊΡ‚ΠΈΡ€ΠΎΠ²Π°Ρ‚ΡŒ ΡΠ»ΠΎΠ²Π°Ρ€ΡŒ", width='stretch'):
 
 
 
 
407
  vocab_dialog()
408
 
409
 
 
 
410
  # ──────────────────────────────────────────────
411
+ # Π¨Π°ΠΏΠΊΠ°
412
  # ──────────────────────────────────────────────
413
  st.markdown("""
414
  <div class="app-header">
415
  <p class="app-title">Ru2SQL β€” гСнСративная модСль прСобразования запросов<br>
416
  ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… Π½Π° русском языкС Π² запросы Π½Π° языкС SQL</p>
417
  <p class="app-subtitle">
418
+ Qwen2.5-Coder-3B-Instruct &nbsp;Β·&nbsp; QLoRA Π½Π° PAUQ
419
  &nbsp;Β·&nbsp; SQLite / PostgreSQL / MySQL
420
  </p>
421
  </div>
 
423
 
424
  tab_query, tab_schema, tab_history = st.tabs(["Запрос", "Π‘Ρ…Π΅ΠΌΠ° Π±Π°Π·Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…", "Π˜ΡΡ‚ΠΎΡ€ΠΈΡ"])
425
 
426
+
427
+ # ──────────── Tab: Запрос ────────────
428
  with tab_query:
429
+ api_ready = (
430
+ st.session_state.api_health is not None
431
+ and st.session_state.api_health.get("model_loaded", False)
432
+ )
433
+ db_ready = st.session_state.schema_tables is not None
434
+ ready = api_ready and db_ready
435
 
436
  if not ready:
437
  missing = []
438
+ if not api_ready:
439
+ missing.append("API/модСль Π½Π΅ Π³ΠΎΡ‚ΠΎΠ²Ρ‹")
440
+ if not db_ready:
441
  missing.append("Π±Π°Π·Π° Π΄Π°Π½Π½Ρ‹Ρ… Π½Π΅ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½Π°")
442
  st.warning("БистСма Π½Π΅ Π³ΠΎΡ‚ΠΎΠ²Π°: " + ", ".join(missing) + ". Π˜ΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠΉΡ‚Π΅ панСль слСва.")
443
 
 
448
  disabled=not ready,
449
  )
450
 
451
+ col_btn, _ = st.columns([1, 5])
452
  with col_btn:
453
  run_btn = st.button(
454
  "Π’Ρ‹ΠΏΠΎΠ»Π½ΠΈΡ‚ΡŒ",
455
  type="primary",
456
  disabled=not ready or not question.strip(),
457
+ width='stretch',
458
  )
459
 
460
  # ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ для Π΄Π΅ΠΌΠΎ-Π±Π°Π·Ρ‹
461
  if (
462
+ st.session_state.connection_string
463
+ and "sales" in st.session_state.connection_string
464
  ):
465
  st.markdown('<p class="examples-label">ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ запросов</p>', unsafe_allow_html=True)
466
  ex_cols = st.columns(3)
 
471
  ]
472
  for i, ex in enumerate(examples):
473
  with ex_cols[i]:
474
+ if st.button(ex, key=f"ex_{i}", width='stretch'):
475
  question = ex
476
+ run_btn = True
477
 
478
  if run_btn and question.strip():
479
+ cs = st.session_state.connection_string
480
+ vocab = st.session_state.vocabulary
 
 
 
 
 
481
 
482
+ with st.spinner("Запрос ΠΊ API. Π­Ρ‚ΠΎ ΠΌΠΎΠΆΠ΅Ρ‚ Π·Π°Π½ΡΡ‚ΡŒ нСсколько ΠΌΠΈΠ½ΡƒΡ‚"):
483
+ try:
484
+ resp = _api_query(question, cs, vocab)
485
+ except Exception as e:
486
+ st.error(f"Ошибка: {e}")
487
+ st.stop()
488
 
489
  st.markdown("**Π‘Π³Π΅Π½Π΅Ρ€ΠΈΡ€ΠΎΠ²Π°Π½Π½Ρ‹ΠΉ SQL**")
490
+ st.markdown(f'<div class="sql-box">{resp.get("sql", "")}</div>', unsafe_allow_html=True)
491
 
492
+ gen_time = resp.get("gen_time_seconds", 0.0)
493
+ execution = resp.get("execution")
494
+ err = resp.get("error")
 
495
 
496
  c1, c2, c3 = st.columns(3)
497
  c1.metric("ВрСмя Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΠΈ", f"{gen_time:.1f} с")
498
+ if execution:
499
+ c2.metric("Π‘Ρ‚Ρ€ΠΎΠΊ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½ΠΎ", execution.get("row_count", 0))
500
+ c3.metric("Бтатус", "УспСшно")
501
+ elif err:
502
+ c2.metric("Π‘Ρ‚Ρ€ΠΎΠΊ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½ΠΎ", "β€”")
503
+ c3.metric("Бтатус", "Ошибка")
504
+
505
+ if execution and execution.get("rows"):
506
+ import pandas as pd
507
+ st.markdown("**Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚**")
508
+ df = pd.DataFrame(execution["rows"], columns=execution["columns"])
509
+ st.dataframe(df, width='stretch')
510
+ elif execution and not execution.get("rows"):
511
+ st.info("Запрос Π²Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ ΡƒΡΠΏΠ΅ΡˆΠ½ΠΎ. Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ пустой.")
512
+ elif err:
513
+ st.error(f"Ошибка выполнСния SQL: {err}")
514
 
515
  st.session_state.history.append({
516
  "question": question,
517
+ "sql": resp.get("sql", ""),
518
+ "success": bool(execution),
519
+ "rows": execution.get("row_count", 0) if execution else 0,
520
  "time": gen_time,
521
  })
522
 
523
+
524
+ # ──────────── Tab: Π‘Ρ…Π΅ΠΌΠ° Π‘Π” ────────────
525
  with tab_schema:
526
+ if st.session_state.schema_tables is None:
527
  st.info("ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡ΠΈΡ‚Π΅ΡΡŒ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… Ρ‡Π΅Ρ€Π΅Π· панСль слСва.")
528
  else:
 
529
  show_samples = st.toggle("ΠŸΠΎΠΊΠ°Π·Ρ‹Π²Π°Ρ‚ΡŒ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…", value=True)
530
+ for t in st.session_state.schema_tables:
531
+ with st.expander(f"{t['name']} β€” {len(t['columns'])} ΠΊΠΎΠ»ΠΎΠ½ΠΎΠΊ"):
532
+ st.code(t.get("ddl", ""), language="sql")
533
+ if show_samples and t.get("sample_rows"):
 
534
  import pandas as pd
535
+ cols = [c["name"] for c in t["columns"]]
536
  st.caption("ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ Π΄Π°Π½Π½Ρ‹Ρ…:")
537
  st.dataframe(
538
+ pd.DataFrame(t["sample_rows"], columns=cols),
539
+ width='stretch',
540
  )
541
 
542
+
543
+ # ──────────── Tab: Π˜ΡΡ‚ΠΎΡ€ΠΈΡ ────────────
544
  with tab_history:
545
  history = st.session_state.history
546
  if not history:
 
550
  with col_h:
551
  st.markdown(f"**Запросов Π² сСссии: {len(history)}**")
552
  with col_clr:
553
+ if st.button("ΠžΡ‡ΠΈΡΡ‚ΠΈΡ‚ΡŒ", width='stretch'):
554
  st.session_state.history = []
555
  st.rerun()
556
 
tests/test_db.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ВСсты Π½Π° DbConnector ΠΈ SqlExecutor.
2
+
3
+ ΠŸΠΎΠΊΡ€Ρ‹Π²Π°ΡŽΡ‚ Ρ‡Ρ‚Π΅Π½ΠΈΠ΅ схСмы SQLite-Π±Π°Π·, Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΡŽ DDL ΠΈ ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΡƒ Ρ‚ΠΎΠ³ΠΎ, Ρ‡Ρ‚ΠΎ
4
+ SQLite-ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ Π΄Π΅ΠΉΡΡ‚Π²ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ открываСтся Π² Ρ€Π΅ΠΆΠΈΠΌΠ΅ read-only β€”
5
+ ΠΌΠΎΠ΄ΠΈΡ„ΠΈΡ†ΠΈΡ€ΡƒΡŽΡ‰ΠΈΠ΅ ΠΎΠΏΠ΅Ρ€Π°Ρ†ΠΈΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ‹ ΠΏΠ°Π΄Π°Ρ‚ΡŒ с sqlite3.OperationalError.
6
+ """
7
+
8
+ import sqlite3
9
+ from pathlib import Path
10
+
11
+ import pytest
12
+
13
+ from src.db.connector import DbConnector, TableInfo
14
+ from src.db.executor import QueryResult, SqlExecutor
15
+
16
+
17
+ @pytest.fixture
18
+ def tiny_sqlite(tmp_path: Path) -> Path:
19
+ db = tmp_path / "tiny.sqlite"
20
+ conn = sqlite3.connect(db)
21
+ conn.execute(
22
+ "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, city TEXT)"
23
+ )
24
+ conn.executemany(
25
+ "INSERT INTO users (id, name, city) VALUES (?, ?, ?)",
26
+ [(1, "Иван", "Казань"), (2, "Анна", "Москва"), (3, "ОлСг", "Казань")],
27
+ )
28
+ conn.commit()
29
+ conn.close()
30
+ return db
31
+
32
+
33
+ # ──────────────────────────────────────────────────────────────────────
34
+ # DbConnector
35
+ # ──────────────────────────────────────────────────────────────────────
36
+
37
+ def test_connector_lists_tables(tiny_sqlite: Path):
38
+ c = DbConnector(str(tiny_sqlite))
39
+ assert c.list_tables() == ["users"]
40
+
41
+
42
+ def test_connector_reads_columns(tiny_sqlite: Path):
43
+ c = DbConnector(str(tiny_sqlite))
44
+ tables = c.get_schema(include_samples=False)
45
+ assert len(tables) == 1
46
+ table = tables[0]
47
+ assert isinstance(table, TableInfo)
48
+ names = [col.name for col in table.columns]
49
+ assert names == ["id", "name", "city"]
50
+ # id β€” primary key, name β€” NOT NULL
51
+ pk = next(col for col in table.columns if col.name == "id")
52
+ assert pk.primary_key is True
53
+ nn = next(col for col in table.columns if col.name == "name")
54
+ assert nn.nullable is False
55
+
56
+
57
+ def test_connector_renders_ddl(tiny_sqlite: Path):
58
+ c = DbConnector(str(tiny_sqlite))
59
+ schema_text = c.render_schema(include_samples=True)
60
+ assert "CREATE TABLE users" in schema_text
61
+ assert "PRIMARY KEY" in schema_text
62
+ # sample-строки ΠΏΡ€ΠΎΠΊΠΈΠ½ΡƒΡ‚Ρ‹ коммСнтариями
63
+ assert "Иван" in schema_text or "ОлСг" in schema_text
64
+
65
+
66
+ def test_connector_accepts_sqlite_uri(tiny_sqlite: Path):
67
+ c = DbConnector(f"sqlite:///{tiny_sqlite}")
68
+ assert c.list_tables() == ["users"]
69
+
70
+
71
+ # ──────────────────────────────────────────────────────────────────────
72
+ # SqlExecutor
73
+ # ──────────────────────────────────────────────────────────────────────
74
+
75
+ def test_executor_runs_select(tiny_sqlite: Path):
76
+ ex = SqlExecutor(str(tiny_sqlite))
77
+ res = ex.run("SELECT id, name FROM users ORDER BY id")
78
+ assert isinstance(res, QueryResult)
79
+ assert res.success
80
+ assert res.columns == ["id", "name"]
81
+ assert res.row_count == 3
82
+ assert res.rows[0] == [1, "Иван"]
83
+
84
+
85
+ def test_executor_aggregation(tiny_sqlite: Path):
86
+ ex = SqlExecutor(str(tiny_sqlite))
87
+ res = ex.run("SELECT city, COUNT(*) AS cnt FROM users GROUP BY city ORDER BY cnt DESC")
88
+ assert res.success
89
+ assert res.rows[0] == ["Казань", 2]
90
+
91
+
92
+ def test_executor_returns_error_on_bad_sql(tiny_sqlite: Path):
93
+ ex = SqlExecutor(str(tiny_sqlite))
94
+ res = ex.run("SELEC nonsense FROM users")
95
+ assert not res.success
96
+ assert res.error is not None
97
+
98
+
99
+ def test_executor_blocks_modifications(tiny_sqlite: Path):
100
+ """ΠšΠ»ΡŽΡ‡Π΅Π²Π°Ρ ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ°: SQLite-соСдинСниС открываСтся Π² read-only
101
+ Ρ€Π΅ΠΆΠΈΠΌΠ΅ (URI mode=ro&immutable=1), ΠΌΠΎΠ΄ΠΈΡ„ΠΈΡ†ΠΈΡ€ΡƒΡŽΡ‰ΠΈΠ΅ ΠΎΠΏΠ΅Ρ€Π°Ρ†ΠΈΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ‹
102
+ ΠΏΠ°Π΄Π°Ρ‚ΡŒ ошибкой, Π° Π½Π΅ Π²Ρ‹ΠΏΠΎΠ»Π½ΡΡ‚ΡŒΡΡ Π²Ρ‚ΠΈΡ…ΡƒΡŽ."""
103
+ ex = SqlExecutor(str(tiny_sqlite))
104
+
105
+ res = ex.run("DELETE FROM users WHERE id = 1")
106
+ assert not res.success
107
+ assert res.error is not None
108
+ assert "read" in res.error.lower() or "readonly" in res.error.lower() \
109
+ or "Ρ‚ΠΎΠ»ΡŒΠΊΠΎ для чтСния" in res.error.lower()
110
+
111
+ # ΠŸΠΎΠ΄Ρ‚Π²Π΅Ρ€ΠΆΠ΄Π΅Π½ΠΈΠ΅, Ρ‡Ρ‚ΠΎ Π΄Π°Π½Π½Ρ‹Π΅ Π½Π΅ пострадали
112
+ check = ex.run("SELECT COUNT(*) FROM users")
113
+ assert check.success
114
+ assert check.rows == [[3]]
115
+
116
+
117
+ def test_executor_blocks_drop_table(tiny_sqlite: Path):
118
+ ex = SqlExecutor(str(tiny_sqlite))
119
+ res = ex.run("DROP TABLE users")
120
+ assert not res.success
121
+
122
+ # ΠŸΠΎΠ΄Ρ‚Π²Π΅Ρ€ΠΆΠ΄Π΅Π½ΠΈΠ΅, Ρ‡Ρ‚ΠΎ Ρ‚Π°Π±Π»ΠΈΡ†Π° Π½Π° мСстС
123
+ check = ex.run("SELECT COUNT(*) FROM users")
124
+ assert check.success
125
+
126
+
127
+ def test_queryresult_to_markdown(tiny_sqlite: Path):
128
+ ex = SqlExecutor(str(tiny_sqlite))
129
+ res = ex.run("SELECT id, name FROM users WHERE id = 1")
130
+ md = res.to_markdown_table()
131
+ assert "id" in md and "name" in md
132
+ assert "Иван" in md
tests/test_postprocess.py CHANGED
@@ -1,6 +1,12 @@
1
- """ВСсты Π½Π° постобработку SQL."""
 
 
 
 
 
2
 
3
  from src.models.postprocess import (
 
4
  is_valid_sql,
5
  normalize_sql,
6
  postprocess,
@@ -8,39 +14,165 @@ from src.models.postprocess import (
8
  )
9
 
10
 
11
- def test_strip_markdown_block():
 
 
 
 
12
  raw = "```sql\nSELECT * FROM users;\n```"
13
- assert strip_model_artifacts(raw).startswith("SELECT")
 
 
 
 
 
14
 
15
 
16
  def test_strip_sql_prefix():
17
  raw = "SQL: SELECT 1;"
18
- assert strip_model_artifacts(raw).startswith("SELECT")
19
 
20
 
21
- def test_keeps_first_statement():
 
 
 
 
 
 
 
 
 
 
 
 
22
  raw = "SELECT 1; SELECT 2;"
23
  out = strip_model_artifacts(raw)
24
  assert "SELECT 1" in out
25
  assert "SELECT 2" not in out
26
 
27
 
28
- def test_valid_sql():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  assert is_valid_sql("SELECT * FROM students WHERE id = 1")
30
 
31
 
32
- def test_invalid_sql():
 
 
 
 
33
  assert not is_valid_sql("SELEC * FRM where")
34
 
35
 
36
- def test_normalize_em():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  a = "SELECT * FROM Users"
38
  b = "select * from users"
39
  assert normalize_sql(a) == normalize_sql(b)
40
 
41
 
42
- def test_postprocess_full():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
44
  out = postprocess(raw)
45
- assert out.startswith("SELECT name")
46
  assert "SELECT 2" not in out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ВСсты Π½Π° постобработку SQL ΠΈ связанныС Ρ„ΡƒΠ½ΠΊΡ†ΠΈΠΈ.
2
+
3
+ ΠŸΠΎΠΊΡ€Ρ‹Π²Π°Π΅Ρ‚ Ρ€Π°Π·Π΄Π΅Π» 2.5 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки: чистку Π°Ρ€Ρ‚Π΅Ρ„Π°ΠΊΡ‚ΠΎΠ²,
4
+ Π²Π°Π»ΠΈΠ΄Π°Ρ†ΠΈΡŽ Ρ‡Π΅Ρ€Π΅Π· sqlglot, Π½ΠΎΡ€ΠΌΠ°Π»ΠΈΠ·Π°Ρ†ΠΈΡŽ для Exact Match ΠΈ AST-ΡƒΡ€ΠΎΠ²Π½Π΅Π²Ρ‹ΠΉ
5
+ Π³Π²Π°Ρ€Π΄Π΅ΠΉΠ» is_select_only.
6
+ """
7
 
8
  from src.models.postprocess import (
9
+ is_select_only,
10
  is_valid_sql,
11
  normalize_sql,
12
  postprocess,
 
14
  )
15
 
16
 
17
+ # ──────────────────────────────────────────────────────────────────────
18
+ # strip_model_artifacts
19
+ # ──────────────────────────────────────────────────────────────────────
20
+
21
+ def test_strip_markdown_block_with_lang():
22
  raw = "```sql\nSELECT * FROM users;\n```"
23
+ assert strip_model_artifacts(raw).upper().startswith("SELECT")
24
+
25
+
26
+ def test_strip_markdown_block_without_lang():
27
+ raw = "```\nSELECT id FROM t;\n```"
28
+ assert strip_model_artifacts(raw).upper().startswith("SELECT")
29
 
30
 
31
  def test_strip_sql_prefix():
32
  raw = "SQL: SELECT 1;"
33
+ assert strip_model_artifacts(raw).upper().startswith("SELECT")
34
 
35
 
36
+ def test_strip_russian_prefix():
37
+ raw = "ΠžΡ‚Π²Π΅Ρ‚: SELECT name FROM students;"
38
+ assert strip_model_artifacts(raw).upper().startswith("SELECT")
39
+
40
+
41
+ def test_strip_natural_language_before_select():
42
+ raw = "Π’ΠΎΡ‚ SQL, ΠΊΠΎΡ‚ΠΎΡ€Ρ‹ΠΉ ΠΎΡ‚Π²Π΅Ρ‡Π°Π΅Ρ‚ Π½Π° вопрос: SELECT * FROM t WHERE id = 1;"
43
+ out = strip_model_artifacts(raw)
44
+ assert out.upper().startswith("SELECT")
45
+ assert "Π’ΠΎΡ‚" not in out
46
+
47
+
48
+ def test_keeps_first_statement_of_two():
49
  raw = "SELECT 1; SELECT 2;"
50
  out = strip_model_artifacts(raw)
51
  assert "SELECT 1" in out
52
  assert "SELECT 2" not in out
53
 
54
 
55
+ def test_with_cte_is_preserved():
56
+ raw = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
57
+ out = strip_model_artifacts(raw)
58
+ assert out.upper().startswith("WITH")
59
+
60
+
61
+ def test_strip_returns_empty_on_garbage():
62
+ # НСт Π½ΠΈ ΠΎΠ΄Π½ΠΎΠ³ΠΎ SQL-ΠΊΠ»ΡŽΡ‡Π΅Π²ΠΎΠ³ΠΎ слова β€” ΠΎΠ±Ρ€Π΅Π·Π°Ρ‚ΡŒ Π½Π΅Ρ‡Π΅Π³ΠΎ, Π½ΠΎ ΠΈ пустого
63
+ # ΠΎΡ‚Π²Π΅Ρ‚Π° модСль Π΅Ρ‰Ρ‘ Π½Π΅ Π½Π°Π³Π΅Π½Π΅Ρ€ΠΈΠ»Π°: Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Π΅ΠΌ ΠΊΠ°ΠΊ Π΅ΡΡ‚ΡŒ, валидация
64
+ # отсССт дальшС ΠΏΠΎ ΠΏΠ°ΠΉΠΏΠ»Π°ΠΉΠ½Ρƒ.
65
+ raw = "просто тСкст Π±Π΅Π· запроса"
66
+ assert strip_model_artifacts(raw) == "просто тСкст Π±Π΅Π· запроса"
67
+
68
+
69
+ # ──────────────────────────────────────────────────────────────────────
70
+ # is_valid_sql
71
+ # ──────────────────────────────────────────────────────────────────────
72
+
73
+ def test_valid_select():
74
  assert is_valid_sql("SELECT * FROM students WHERE id = 1")
75
 
76
 
77
+ def test_valid_with_cte():
78
+ assert is_valid_sql("WITH x AS (SELECT id FROM t) SELECT * FROM x")
79
+
80
+
81
+ def test_invalid_garbage():
82
  assert not is_valid_sql("SELEC * FRM where")
83
 
84
 
85
+ def test_invalid_empty():
86
+ assert not is_valid_sql("")
87
+ assert not is_valid_sql(" ")
88
+
89
+
90
+ # ──────────────────────────────────────────────────────────────────────
91
+ # is_select_only β€” guardrail
92
+ # ──────────────────────────────────────────────────────────────────────
93
+
94
+ def test_select_passes_guardrail():
95
+ assert is_select_only("SELECT id FROM t")
96
+
97
+
98
+ def test_with_cte_passes_guardrail():
99
+ assert is_select_only("WITH x AS (SELECT id FROM t) SELECT * FROM x")
100
+
101
+
102
+ def test_drop_table_blocked():
103
+ assert not is_select_only("DROP TABLE users")
104
+
105
+
106
+ def test_delete_blocked():
107
+ assert not is_select_only("DELETE FROM users WHERE id = 1")
108
+
109
+
110
+ def test_update_blocked():
111
+ assert not is_select_only("UPDATE users SET name = 'a' WHERE id = 1")
112
+
113
+
114
+ def test_insert_blocked():
115
+ assert not is_select_only("INSERT INTO users (id, name) VALUES (1, 'a')")
116
+
117
+
118
+ def test_empty_blocked():
119
+ assert not is_select_only("")
120
+ assert not is_select_only(" ")
121
+
122
+
123
+ def test_invalid_sql_blocked_by_guardrail():
124
+ # На Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΠΎΠΉ строкС is_select_only Π΄ΠΎΠ»ΠΆΠ΅Π½ чСстно Π²ΠΎΠ·Π²Ρ€Π°Ρ‰Π°Ρ‚ΡŒ False,
125
+ # Π° Π½Π΅ ΠΏΠ°Π΄Π°Ρ‚ΡŒ с ΠΈΡΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ΠΌ.
126
+ assert not is_select_only("not a sql at all")
127
+
128
+
129
+ # ──────────────────────────────────────────────────────────────────────
130
+ # normalize_sql
131
+ # ──────────────────────────────────────────────────────────────────────
132
+
133
+ def test_normalize_collapses_whitespace():
134
  a = "SELECT * FROM Users"
135
  b = "select * from users"
136
  assert normalize_sql(a) == normalize_sql(b)
137
 
138
 
139
+ def test_normalize_idempotent():
140
+ sql = "SELECT id FROM t WHERE x = 1"
141
+ assert normalize_sql(normalize_sql(sql)) == normalize_sql(sql)
142
+
143
+
144
+ def test_normalize_fallback_on_invalid():
145
+ # На Π½Π΅Π²Π°Π»ΠΈΠ΄Π½ΠΎΠΌ SQL функция Π½Π΅ Π΄ΠΎΠ»ΠΆΠ½Π° ΠΏΠ°Π΄Π°Ρ‚ΡŒ β€” Π΄ΠΎΠ»ΠΆΠ΅Π½ ΡΡ€Π°Π±ΠΎΡ‚Π°Ρ‚ΡŒ fallback.
146
+ out = normalize_sql("not really sql")
147
+ assert isinstance(out, str)
148
+ assert out.upper() == out # Π²Π΅Ρ€Ρ…Π½ΠΈΠΉ рСгистр сохранён
149
+
150
+
151
+ # ──────────────────────────────────────────────────────────────────────
152
+ # postprocess β€” ΠΏΠΎΠ»Π½Ρ‹ΠΉ pipeline
153
+ # ──────────────────────────────────────────────────────────────────────
154
+
155
+ def test_postprocess_extracts_from_markdown():
156
  raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
157
  out = postprocess(raw)
158
+ assert out.upper().startswith("SELECT NAME") or out.startswith("SELECT name")
159
  assert "SELECT 2" not in out
160
+
161
+
162
+ def test_postprocess_returns_empty_on_invalid():
163
+ # ВСкст Π½Π΅ содСрТит Π²Π°Π»ΠΈΠ΄Π½ΠΎΠ³ΠΎ SQL β€” pipeline Π΄ΠΎΠ»ΠΆΠ΅Π½ Π²Π΅Ρ€Π½ΡƒΡ‚ΡŒ ΠΏΡƒΡΡ‚ΡƒΡŽ строку,
164
+ # ΠΊΠ°ΠΊ описано Π² Ρ€Π°Π·Π΄Π΅Π»Π΅ 2.5 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки.
165
+ raw = "Π― Π½Π΅ ΠΌΠΎΠ³Ρƒ ΡΠ³Π΅Π½Π΅Ρ€ΠΈΡ€ΠΎΠ²Π°Ρ‚ΡŒ SQL для этого вопроса."
166
+ assert postprocess(raw) == ""
167
+
168
+
169
+ def test_postprocess_returns_empty_on_truncated():
170
+ # МодСль ΠΎΠ±ΠΎΡ€Π²Π°Π»Π° Π³Π΅Π½Π΅Ρ€Π°Ρ†ΠΈΡŽ Π½Π° сСрСдинС запроса β€” Π½Π΅Π²Π°Π»ΠΈΠ΄Π½Ρ‹ΠΉ синтаксис.
171
+ raw = "SELECT * FROM users WHERE"
172
+ assert postprocess(raw) == ""
173
+
174
+
175
+ def test_postprocess_keeps_valid_with_cte():
176
+ raw = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
177
+ out = postprocess(raw)
178
+ assert out.upper().startswith("WITH")
tests/test_prompt.py CHANGED
@@ -1,8 +1,15 @@
1
- """ВСсты Π½Π° PromptBuilder."""
2
 
 
 
 
 
 
3
  from src.data.prompt import (
 
4
  SYSTEM_PROMPT,
5
  build_chat_messages,
 
6
  build_training_example,
7
  build_user_message,
8
  )
@@ -21,7 +28,7 @@ def test_chat_messages_have_system_and_user():
21
  msgs = build_chat_messages("schema", "question")
22
  assert len(msgs) == 2
23
  assert msgs[0]["role"] == "system"
24
- assert msgs[0]["content"] == SYSTEM_PROMPT
25
  assert msgs[1]["role"] == "user"
26
 
27
 
@@ -30,3 +37,51 @@ def test_training_example_has_assistant():
30
  assert len(msgs) == 3
31
  assert msgs[2]["role"] == "assistant"
32
  assert msgs[2]["content"] == "SELECT 1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ВСсты Π½Π° PromptBuilder.
2
 
3
+ ΠŸΠΎΠΊΡ€Ρ‹Π²Π°ΡŽΡ‚ ΠΊΠ°ΠΊ Π±Π°Π·ΠΎΠ²ΠΎΠ΅ Ρ„ΠΎΡ€ΠΌΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ chat-template, Ρ‚Π°ΠΊ ΠΈ ΠΎΠΏΡ†ΠΈΠΎΠ½Π°Π»ΡŒΠ½ΡƒΡŽ
4
+ ΠΈΠ½Ρ‚Π΅Π³Ρ€Π°Ρ†ΠΈΡŽ BusinessVocabulary Π² систСмноС сообщСниС (Ρ€Π°Π·Π΄Π΅Π» 3.6 Π’ΠšΠ ).
5
+ """
6
+
7
+ from src.business.vocabulary import BusinessVocabulary
8
  from src.data.prompt import (
9
+ BASE_SYSTEM_PROMPT,
10
  SYSTEM_PROMPT,
11
  build_chat_messages,
12
+ build_system_message,
13
  build_training_example,
14
  build_user_message,
15
  )
 
28
  msgs = build_chat_messages("schema", "question")
29
  assert len(msgs) == 2
30
  assert msgs[0]["role"] == "system"
31
+ assert msgs[0]["content"] == BASE_SYSTEM_PROMPT
32
  assert msgs[1]["role"] == "user"
33
 
34
 
 
37
  assert len(msgs) == 3
38
  assert msgs[2]["role"] == "assistant"
39
  assert msgs[2]["content"] == "SELECT 1"
40
+
41
+
42
+ def test_legacy_system_prompt_alias():
43
+ assert SYSTEM_PROMPT == BASE_SYSTEM_PROMPT
44
+
45
+
46
+ def test_system_message_without_vocabulary():
47
+ assert build_system_message(None) == BASE_SYSTEM_PROMPT
48
+
49
+
50
+ def test_system_message_with_empty_vocabulary():
51
+ vocab = BusinessVocabulary.empty()
52
+ assert build_system_message(vocab) == BASE_SYSTEM_PROMPT
53
+
54
+
55
+ def test_system_message_with_terms():
56
+ vocab = BusinessVocabulary(
57
+ company="ООО Ромашка",
58
+ terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(orders.amount) WHERE orders.status = 'paid'"},
59
+ )
60
+ msg = build_system_message(vocab)
61
+ assert msg.startswith(BASE_SYSTEM_PROMPT)
62
+ assert "ООО Ромашка" in msg
63
+ assert "Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°" in msg
64
+ assert "SUM(orders.amount)" in msg
65
+
66
+
67
+ def test_chat_messages_with_vocabulary_keeps_user_clean():
68
+ vocab = BusinessVocabulary(
69
+ terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount) WHERE status='paid'"},
70
+ )
71
+ msgs = build_chat_messages("schema", "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°?", vocabulary=vocab)
72
+ assert msgs[0]["role"] == "system"
73
+ assert "SUM(amount)" in msgs[0]["content"]
74
+ assert msgs[1]["role"] == "user"
75
+ assert "SUM(amount)" not in msgs[1]["content"]
76
+ assert "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°?" in msgs[1]["content"]
77
+
78
+
79
+ def test_training_example_with_vocabulary():
80
+ vocab = BusinessVocabulary(terms={"Ρ‚ΠΎΠΏ": "ORDER BY x DESC LIMIT 10"})
81
+ msgs = build_training_example(
82
+ "schema", "Π’ΠΎΠΏ ΠΊΠ»ΠΈΠ΅Π½Ρ‚ΠΎΠ²", "SELECT 1", vocabulary=vocab
83
+ )
84
+ assert len(msgs) == 3
85
+ assert msgs[0]["role"] == "system"
86
+ assert "ORDER BY x DESC" in msgs[0]["content"]
87
+ assert msgs[2]["role"] == "assistant"
tests/test_schema_provider.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ВСсты Π½Π° Π΅Π΄ΠΈΠ½Ρ‹ΠΉ SchemaProvider (Ρ€Π°Π·Π΄Π΅Π» 4.2 Π°ΡƒΠ΄ΠΈΡ‚Π°).
2
+
3
+ ΠŸΠΎΠΊΡ€Ρ‹Π²Π°ΡŽΡ‚ ΠΎΠ±Π΅ Ρ€Π΅Π°Π»ΠΈΠ·Π°Ρ†ΠΈΠΈ: SpiderSchemaProvider (структура PAUQ/Spider)
4
+ ΠΈ ConnectionSchemaProvider (ΠΏΡ€ΠΎΠΈΠ·Π²ΠΎΠ»ΡŒΠ½ΠΎΠ΅ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΊ SQLite-Ρ„Π°ΠΉΠ»Ρƒ).
5
+ """
6
+
7
+ import sqlite3
8
+ from pathlib import Path
9
+
10
+ import pytest
11
+
12
+ from src.data.schema_provider import (
13
+ ColumnSchema,
14
+ ConnectionSchemaProvider,
15
+ SpiderSchemaProvider,
16
+ TableSchema,
17
+ render_tables,
18
+ )
19
+
20
+
21
+ # ──────────────────────────────────────────────────────────────────────
22
+ # Ѐикстуры
23
+ # ──────────────────────────────────────────────────────────────────────
24
+
25
+ @pytest.fixture
26
+ def spider_dir(tmp_path: Path) -> Path:
27
+ """data/databases/uni/uni.sqlite + data/databases/sales/sales.sqlite."""
28
+ for db_id in ("uni", "sales"):
29
+ (tmp_path / db_id).mkdir()
30
+ db = tmp_path / db_id / f"{db_id}.sqlite"
31
+ conn = sqlite3.connect(db)
32
+ conn.execute(f"CREATE TABLE {db_id}_t (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
33
+ conn.execute(f"INSERT INTO {db_id}_t VALUES (1, '{db_id}-row')")
34
+ conn.commit()
35
+ conn.close()
36
+ return tmp_path
37
+
38
+
39
+ @pytest.fixture
40
+ def tiny_db(tmp_path: Path) -> Path:
41
+ db = tmp_path / "tiny.sqlite"
42
+ conn = sqlite3.connect(db)
43
+ conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
44
+ conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "Иван"), (2, "Анна")])
45
+ conn.commit()
46
+ conn.close()
47
+ return db
48
+
49
+
50
+ # ──────────────────────────────────────────────────────────────────────
51
+ # TableSchema.to_ddl
52
+ # ──────────────────────────────────────────────────────────────────────
53
+
54
+ def test_table_schema_to_ddl_from_create_sql():
55
+ t = TableSchema(
56
+ name="t",
57
+ create_sql="CREATE TABLE t (id INT PRIMARY KEY, name TEXT)",
58
+ )
59
+ assert t.to_ddl() == "CREATE TABLE t (id INT PRIMARY KEY, name TEXT);"
60
+
61
+
62
+ def test_table_schema_to_ddl_from_columns():
63
+ t = TableSchema(
64
+ name="users",
65
+ columns=[
66
+ ColumnSchema(name="id", type="INTEGER", primary_key=True, nullable=False),
67
+ ColumnSchema(name="email", type="TEXT", nullable=False),
68
+ ],
69
+ )
70
+ ddl = t.to_ddl()
71
+ assert ddl.startswith("CREATE TABLE users")
72
+ assert "id INTEGER PRIMARY KEY NOT NULL" in ddl
73
+ assert "email TEXT NOT NULL" in ddl
74
+
75
+
76
+ # ──────────────────────────────────────────────────────────────────────
77
+ # SpiderSchemaProvider
78
+ # ──────────────────────────────────────────────────────────────────────
79
+
80
+ def test_spider_lists_databases(spider_dir: Path):
81
+ p = SpiderSchemaProvider(spider_dir)
82
+ assert p.list_databases() == ["sales", "uni"]
83
+
84
+
85
+ def test_spider_db_path_resolves(spider_dir: Path):
86
+ p = SpiderSchemaProvider(spider_dir)
87
+ path = p.db_path("uni")
88
+ assert path.exists()
89
+ assert path.name == "uni.sqlite"
90
+
91
+
92
+ def test_spider_db_path_raises_on_missing(spider_dir: Path):
93
+ p = SpiderSchemaProvider(spider_dir)
94
+ with pytest.raises(FileNotFoundError):
95
+ p.db_path("nonexistent")
96
+
97
+
98
+ def test_spider_get_tables_returns_tableschema(spider_dir: Path):
99
+ p = SpiderSchemaProvider(spider_dir)
100
+ tables = p.get_tables("uni")
101
+ assert len(tables) == 1
102
+ assert isinstance(tables[0], TableSchema)
103
+ assert tables[0].name == "uni_t"
104
+
105
+
106
+ def test_spider_render_schema_has_create(spider_dir: Path):
107
+ p = SpiderSchemaProvider(spider_dir)
108
+ text = p.render_schema("uni")
109
+ assert "CREATE TABLE" in text
110
+ assert "uni_t" in text
111
+
112
+
113
+ # ──────────────────────────────────────────────────────────────────────
114
+ # ConnectionSchemaProvider
115
+ # ──────────────────────────────────────────────────────────────────────
116
+
117
+ def test_connection_lists_tables(tiny_db: Path):
118
+ p = ConnectionSchemaProvider(str(tiny_db))
119
+ assert p.list_tables() == ["users"]
120
+
121
+
122
+ def test_connection_get_tables_columns(tiny_db: Path):
123
+ p = ConnectionSchemaProvider(str(tiny_db))
124
+ tables = p.get_tables()
125
+ assert len(tables) == 1
126
+ cols = {c.name for c in tables[0].columns}
127
+ assert cols == {"id", "name"}
128
+
129
+
130
+ def test_connection_render_schema_with_samples(tiny_db: Path):
131
+ p = ConnectionSchemaProvider(str(tiny_db))
132
+ text = p.render_schema(include_samples=True)
133
+ assert "CREATE TABLE users" in text
134
+ assert "Иван" in text or "Анна" in text
135
+
136
+
137
+ def test_connection_test_connection(tiny_db: Path):
138
+ p = ConnectionSchemaProvider(str(tiny_db))
139
+ assert p.test_connection() is True
140
+
141
+
142
+ # ──────────────────────────────────────────────────────────────────────
143
+ # Π¦Π΅ΠΏΠΎΡ‡ΠΊΠ° SpiderSchemaProvider.for_database β†’ ConnectionSchemaProvider
144
+ # ──────────────────────────────────────────────────────────────────────
145
+
146
+ def test_spider_for_database_returns_connection_provider(spider_dir: Path):
147
+ p = SpiderSchemaProvider(spider_dir)
148
+ sub = p.for_database("sales")
149
+ assert isinstance(sub, ConnectionSchemaProvider)
150
+ text = sub.render_schema()
151
+ assert "sales_t" in text
152
+
153
+
154
+ # ──────────────────────────────────────────────────────────────────────
155
+ # render_tables β€” общая ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Π°
156
+ # ──────────────────────────────────────────────────────────────────────
157
+
158
+ def test_render_tables_groups_ddl_and_samples():
159
+ tables = [
160
+ TableSchema(
161
+ name="x",
162
+ columns=[ColumnSchema(name="id", type="INT")],
163
+ sample_rows=[(1,), (2,)],
164
+ ),
165
+ ]
166
+ text = render_tables(tables, include_samples=True)
167
+ assert "CREATE TABLE x" in text
168
+ assert "ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк" in text
169
+ assert "(1," in text and "(2," in text
170
+
171
+
172
+ def test_render_tables_no_samples():
173
+ tables = [
174
+ TableSchema(
175
+ name="x",
176
+ columns=[ColumnSchema(name="id", type="INT")],
177
+ sample_rows=[(1,)],
178
+ ),
179
+ ]
180
+ text = render_tables(tables, include_samples=False)
181
+ assert "CREATE TABLE x" in text
182
+ assert "ΠŸΡ€ΠΈΠΌΠ΅Ρ€Ρ‹ строк" not in text
tests/test_vocabulary.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ВСсты Π½Π° BusinessVocabulary (Ρ€Π°Π·Π΄Π΅Π» 3.6 ΠΏΠΎΡΡΠ½ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎΠΉ записки)."""
2
+
3
+ from pathlib import Path
4
+
5
+ import pytest
6
+
7
+ from src.business.vocabulary import BusinessVocabulary
8
+
9
+
10
+ # ──────────────────────────────────────────────────────────────────────
11
+ # Π—Π°Π³Ρ€ΡƒΠ·ΠΊΠ°
12
+ # ──────────────────────────────────────────────────────────────────────
13
+
14
+ def test_empty_vocabulary_is_falsy():
15
+ vocab = BusinessVocabulary.empty()
16
+ assert not vocab
17
+ assert vocab.render_system_context() == ""
18
+
19
+
20
+ def test_from_dict_basic():
21
+ vocab = BusinessVocabulary.from_dict(
22
+ {
23
+ "company": "ООО ВСст",
24
+ "terms": {"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount)"},
25
+ "filters": {"paid": "status='paid'"},
26
+ "notes": ["ΠŸΠ΅Ρ€ΠΈΠΎΠ΄: 2025-01-01β€”2026-04-30"],
27
+ }
28
+ )
29
+ assert bool(vocab)
30
+ assert vocab.company == "ООО ВСст"
31
+ assert vocab.terms == {"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount)"}
32
+ assert vocab.filters == {"paid": "status='paid'"}
33
+ assert vocab.notes == ["ΠŸΠ΅Ρ€ΠΈΠΎΠ΄: 2025-01-01β€”2026-04-30"]
34
+
35
+
36
+ def test_from_yaml_roundtrip(tmp_path: Path):
37
+ original = BusinessVocabulary(
38
+ company="ООО Ромашка",
39
+ terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(orders.amount)"},
40
+ filters={"paid_only": "orders.status='paid'"},
41
+ notes=["amount Ρ‚ΠΎΠ»ΡŒΠΊΠΎ Π² orders"],
42
+ )
43
+ path = tmp_path / "vocab.yaml"
44
+ original.save_yaml(path)
45
+
46
+ loaded = BusinessVocabulary.from_yaml(path)
47
+ assert loaded.company == original.company
48
+ assert loaded.terms == original.terms
49
+ assert loaded.filters == original.filters
50
+ assert loaded.notes == original.notes
51
+
52
+
53
+ def test_from_yaml_missing_file(tmp_path: Path):
54
+ with pytest.raises(FileNotFoundError):
55
+ BusinessVocabulary.from_yaml(tmp_path / "does_not_exist.yaml")
56
+
57
+
58
+ # ──────────────────────────────────────────────────────────────────────
59
+ # enrich_prompt β€” обратная ΡΠΎΠ²ΠΌΠ΅ΡΡ‚ΠΈΠΌΠΎΡΡ‚ΡŒ со старым Streamlit-ΠΊΠΎΠ΄ΠΎΠΌ
60
+ # ──────────────────────────────────────────────────────────────────────
61
+
62
+ def test_enrich_prompt_pass_through_when_empty():
63
+ vocab = BusinessVocabulary.empty()
64
+ assert vocab.enrich_prompt("Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°?") == "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°?"
65
+
66
+
67
+ def test_enrich_prompt_adds_term_definition():
68
+ vocab = BusinessVocabulary(
69
+ terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount) WHERE status='paid'"},
70
+ )
71
+ enriched = vocab.enrich_prompt("Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° ΡΠ½Π²Π°Ρ€ΡŒ?")
72
+ assert "Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°" in enriched
73
+ assert "SUM(amount)" in enriched
74
+ assert "Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ° Π·Π° ΡΠ½Π²Π°Ρ€ΡŒ?" in enriched
75
+
76
+
77
+ def test_enrich_prompt_case_insensitive_match():
78
+ vocab = BusinessVocabulary(terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount)"})
79
+ enriched = vocab.enrich_prompt("Π’Π«Π Π£Π§ΠšΠ Π² этом мСсяцС?")
80
+ assert "SUM(amount)" in enriched
81
+
82
+
83
+ def test_enrich_prompt_ignores_unrelated_terms():
84
+ vocab = BusinessVocabulary(
85
+ terms={
86
+ "Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount)",
87
+ "Ρ‚ΠΎΠΏ ΠΊΠ»ΠΈΠ΅Π½Ρ‚ΠΎΠ²": "ORDER BY SUM(amount) DESC",
88
+ },
89
+ )
90
+ enriched = vocab.enrich_prompt("Какая Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°?")
91
+ # Π Π΅Π»Π΅Π²Π°Π½Ρ‚Π½Ρ‹ΠΉ Ρ‚Π΅Ρ€ΠΌΠΈΠ½ подмСшан, посторонний β€” Π½Π΅Ρ‚
92
+ assert "SUM(amount)" in enriched
93
+ assert "ORDER BY SUM(amount) DESC" not in enriched
94
+
95
+
96
+ # ──────────────────────────────────────────────────────────────────────
97
+ # render_system_context β€” Ρ‚ΠΎ, Ρ‡Ρ‚ΠΎ ΠΏΠΎΠ΄ΠΌΠ΅ΡˆΠΈΠ²Π°Π΅Ρ‚ΡΡ Π² system-сообщСниС
98
+ # ──────────────────────────────────────────────────────────────────────
99
+
100
+ def test_render_system_context_empty():
101
+ assert BusinessVocabulary.empty().render_system_context() == ""
102
+
103
+
104
+ def test_render_system_context_with_company():
105
+ vocab = BusinessVocabulary(
106
+ company="ООО Ромашка",
107
+ terms={"Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°": "SUM(amount)"},
108
+ )
109
+ ctx = vocab.render_system_context()
110
+ assert "ООО Ромашка" in ctx
111
+ assert "Π²Ρ‹Ρ€ΡƒΡ‡ΠΊΠ°" in ctx
112
+ assert "SUM(amount)" in ctx
113
+
114
+
115
+ def test_render_system_context_with_filters_and_notes():
116
+ vocab = BusinessVocabulary(
117
+ filters={"Ρ‚ΠΎΠ»ΡŒΠΊΠΎ_ΠΎΠΏΠ»Π°Ρ‡Π΅Π½Π½Ρ‹Π΅": "orders.status='paid'"},
118
+ notes=["amount Ρ‚ΠΎΠ»ΡŒΠΊΠΎ Π² orders"],
119
+ )
120
+ ctx = vocab.render_system_context()
121
+ assert "Ρ‚ΠΎΠ»ΡŒΠΊΠΎ_ΠΎΠΏΠ»Π°Ρ‡Π΅Π½Π½Ρ‹Π΅" in ctx
122
+ assert "orders.status='paid'" in ctx
123
+ assert "amount Ρ‚ΠΎΠ»ΡŒΠΊΠΎ Π² orders" in ctx