VladGeekPro Copilot committed on
Commit
45b18ac
·
1 Parent(s): 81b7609

CreatedDBAgent

Browse files

Co-authored-by: Copilot <copilot@github.com>

Files changed (3) hide show
  1. app.py +22 -0
  2. requirements.txt +1 -0
  3. sql_generator.py +181 -0
app.py CHANGED
@@ -24,6 +24,7 @@ from extractors import (
24
  ExpenseUserExtractor,
25
  ExpenseAmountExtractor,
26
  )
 
27
 
28
 
29
  # HuggingFace Token (если нужен для моделей)
@@ -401,6 +402,12 @@ def parse_context(raw: str | None) -> dict[str, Any]:
401
  return {}
402
 
403
 
 
 
 
 
 
 
404
  # ============================================================================
405
  # ENDPOINTS
406
  # ============================================================================
@@ -413,6 +420,7 @@ def index():
413
  "message": "Voice processing API is running",
414
  "endpoints": {
415
  "POST /process-audio": "Process audio file",
 
416
  "GET /health": "Health check",
417
  "GET /test-data": "Run text-only extraction tests"
418
  }
@@ -495,5 +503,19 @@ def process_audio():
495
  os.unlink(temp_path)
496
 
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  if __name__ == "__main__":
499
  app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
 
24
  ExpenseUserExtractor,
25
  ExpenseAmountExtractor,
26
  )
27
+ from sql_generator import generate_sql
28
 
29
 
30
  # HuggingFace Token (если нужен для моделей)
 
402
  return {}
403
 
404
 
405
def parse_json_payload() -> dict[str, Any]:
    """Return the JSON body of the current request as a dict.

    Falls back to an empty dict when the body is missing, malformed,
    or is valid JSON that is not an object.
    """
    data = request.get_json(silent=True)
    if isinstance(data, dict):
        return data
    return {}
409
+
410
+
411
  # ============================================================================
412
  # ENDPOINTS
413
  # ============================================================================
 
420
  "message": "Voice processing API is running",
421
  "endpoints": {
422
  "POST /process-audio": "Process audio file",
423
+ "POST /generate-sql": "Generate SQLite SELECT query from natural language",
424
  "GET /health": "Health check",
425
  "GET /test-data": "Run text-only extraction tests"
426
  }
 
503
  os.unlink(temp_path)
504
 
505
 
506
@app.post("/generate-sql")
def generate_sql_endpoint():
    """Generate a SQLite SELECT query from a natural-language request.

    Expects a JSON body with "query" (or "text") and an optional integer
    "limit" (defaults to 200). Responds with {"sql": ...} on success,
    422 for invalid input or unusable generated SQL, and 500 for
    unexpected server-side failures (e.g. model loading errors).
    """
    payload = parse_json_payload()
    query = payload.get("query") or payload.get("text") or ""
    raw_limit = payload.get("limit") or 200

    try:
        limit = int(raw_limit)
    except (TypeError, ValueError):
        return jsonify({"status": "error", "message": "Field 'limit' must be an integer."}), 422

    try:
        sql = generate_sql(question=query, limit=limit)
    except ValueError as exception:
        # Validation failures: empty query, non-SELECT output, forbidden SQL.
        return jsonify({"status": "error", "message": str(exception)}), 422
    except Exception as exception:
        # Anything else (model download/inference failure) is a server error,
        # not the client's fault — report 500 instead of 422.
        return jsonify({"status": "error", "message": str(exception)}), 500

    return jsonify({"sql": sql})
518
+
519
+
520
  if __name__ == "__main__":
521
  app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
requirements.txt CHANGED
@@ -9,3 +9,4 @@ python-dateutil
9
  iuliia
10
  torch
11
  scikit-learn
 
 
9
  iuliia
10
  torch
11
  scikit-learn
12
+ sentencepiece
sql_generator.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import os
import re
from dataclasses import dataclass
from typing import Any

import torch
8
+
9
+
10
# SQLite DDL for the target database, embedded verbatim in the model prompt so
# the model only references real tables/columns.
# NOTE(review): presumably dumped from the application's migrations
# (`sqlite3 .schema`-style output) — confirm the source of truth and keep this
# in sync whenever the schema changes.
DEFAULT_DB_SCHEMA = """CREATE TABLE \"migrations\" (\"id\" integer primary key autoincrement not null, \"migration\" varchar not null, \"batch\" integer not null);

CREATE TABLE \"users\" (\"id\" integer primary key autoincrement not null, \"image\" varchar, \"name\" varchar not null, \"email\" varchar not null, \"email_verified_at\" datetime, \"password\" varchar not null, \"widget_preferences\" text, \"remember_token\" varchar, \"created_at\" datetime, \"updated_at\" datetime);

CREATE UNIQUE INDEX \"users_email_unique\" on \"users\" (\"email\");

CREATE TABLE \"password_reset_tokens\" (\"email\" varchar not null, \"token\" varchar not null, \"created_at\" datetime, primary key (\"email\"));

CREATE TABLE \"sessions\" (\"id\" varchar not null, \"user_id\" integer, \"ip_address\" varchar, \"user_agent\" text, \"payload\" text not null, \"last_activity\" integer not null, primary key (\"id\"));

CREATE INDEX \"sessions_user_id_index\" on \"sessions\" (\"user_id\");

CREATE INDEX \"sessions_last_activity_index\" on \"sessions\" (\"last_activity\");

CREATE TABLE \"cache\" (\"key\" varchar not null, \"value\" text not null, \"expiration\" integer not null, primary key (\"key\"));

CREATE TABLE \"cache_locks\" (\"key\" varchar not null, \"owner\" varchar not null, \"expiration\" integer not null, primary key (\"key\"));

CREATE TABLE \"jobs\" (\"id\" integer primary key autoincrement not null, \"queue\" varchar not null, \"payload\" text not null, \"attempts\" integer not null, \"reserved_at\" integer, \"available_at\" integer not null, \"created_at\" integer not null);

CREATE INDEX \"jobs_queue_index\" on \"jobs\" (\"queue\");

CREATE TABLE \"job_batches\" (\"id\" varchar not null, \"name\" varchar not null, \"total_jobs\" integer not null, \"pending_jobs\" integer not null, \"failed_jobs\" integer not null, \"failed_job_ids\" text not null, \"options\" text, \"cancelled_at\" integer, \"created_at\" integer not null, \"finished_at\" integer, primary key (\"id\"));

CREATE TABLE \"failed_jobs\" (\"id\" integer primary key autoincrement not null, \"uuid\" varchar not null, \"connection\" text not null, \"queue\" text not null, \"payload\" text not null, \"exception\" text not null, \"failed_at\" datetime not null default CURRENT_TIMESTAMP);

CREATE UNIQUE INDEX \"failed_jobs_uuid_unique\" on \"failed_jobs\" (\"uuid\");

CREATE TABLE \"categories\" (\"id\" integer primary key autoincrement not null, \"name\" varchar not null, \"slug\" varchar not null, \"image\" varchar, \"notes\" text not null, \"created_at\" datetime, \"updated_at\" datetime);

CREATE UNIQUE INDEX \"categories_slug_unique\" on \"categories\" (\"slug\");

CREATE TABLE \"suppliers\" (\"id\" integer primary key autoincrement not null, \"name\" varchar not null, \"slug\" varchar not null, \"image\" varchar, \"category_id\" integer not null, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"category_id\") references \"categories\"(\"id\") on delete cascade);

CREATE UNIQUE INDEX \"suppliers_slug_unique\" on \"suppliers\" (\"slug\");

CREATE TABLE \"expenses\" (\"id\" integer primary key autoincrement not null, \"user_id\" integer not null, \"date\" date not null, \"category_id\" integer not null, \"supplier_id\" integer not null, \"sum\" numeric not null, \"notes\" text, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"user_id\") references \"users\"(\"id\") on delete cascade, foreign key(\"category_id\") references \"categories\"(\"id\") on delete set null, foreign key(\"supplier_id\") references \"suppliers\"(\"id\") on delete set null);

CREATE INDEX \"expenses_date_index\" on \"expenses\" (\"date\");

CREATE INDEX \"expenses_user_id_date_index\" on \"expenses\" (\"user_id\", \"date\");

CREATE TABLE \"overpayments\" (\"id\" integer primary key autoincrement not null, \"user_id\" integer not null, \"sum\" numeric not null, \"notes\" text not null, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"user_id\") references \"users\"(\"id\") on delete cascade);

CREATE INDEX \"overpayments_created_at_index\" on \"overpayments\" (\"created_at\");

CREATE TABLE \"debts\" (\"id\" integer primary key autoincrement not null, \"date\" date not null, \"user_id\" integer, \"debt_sum\" numeric not null default '0', \"overpayment_id\" integer, \"notes\" text not null, \"payment_status\" varchar check (\"payment_status\" in ('unpaid', 'partial', 'paid')) not null default 'unpaid', \"partial_sum\" numeric not null default '0', \"date_paid\" date, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"user_id\") references \"users\"(\"id\") on delete cascade, foreign key(\"overpayment_id\") references \"overpayments\"(\"id\") on delete cascade);

CREATE INDEX \"debts_payment_status_date_index\" on \"debts\" (\"payment_status\", \"date\");

CREATE UNIQUE INDEX \"debts_date_unique\" on \"debts\" (\"date\");

CREATE TABLE \"notifications\" (\"id\" varchar not null, \"type\" varchar not null, \"notifiable_type\" varchar not null, \"notifiable_id\" integer not null, \"data\" text not null, \"read_at\" datetime, \"created_at\" datetime, \"updated_at\" datetime, primary key (\"id\"));

CREATE INDEX \"notifications_notifiable_type_notifiable_id_index\" on \"notifications\" (\"notifiable_type\", \"notifiable_id\");

CREATE TABLE \"expense_change_requests\" (\"id\" integer primary key autoincrement not null, \"expense_id\" integer, \"user_id\" integer not null, \"action_type\" varchar check (\"action_type\" in ('create', 'edit', 'delete')) not null, \"current_date\" date, \"current_user_id\" integer, \"current_category_id\" integer, \"current_supplier_id\" integer, \"current_sum\" numeric, \"current_notes\" text, \"requested_date\" date, \"requested_user_id\" integer, \"requested_category_id\" integer, \"requested_supplier_id\" integer, \"requested_sum\" numeric, \"requested_notes\" text, \"notes\" text not null, \"status\" varchar check (\"status\" in ('pending', 'rejected', 'completed')) not null default 'pending', \"applied_at\" datetime, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"expense_id\") references \"expenses\"(\"id\") on delete set null, foreign key(\"user_id\") references \"users\"(\"id\") on delete cascade, foreign key(\"current_user_id\") references \"users\"(\"id\") on delete set null, foreign key(\"current_category_id\") references \"categories\"(\"id\") on delete set null, foreign key(\"current_supplier_id\") references \"suppliers\"(\"id\") on delete set null, foreign key(\"requested_user_id\") references \"users\"(\"id\") on delete set null, foreign key(\"requested_category_id\") references \"categories\"(\"id\") on delete set null, foreign key(\"requested_supplier_id\") references \"suppliers\"(\"id\") on delete set null);

CREATE INDEX \"expense_change_requests_status_created_at_index\" on \"expense_change_requests\" (\"status\", \"created_at\");

CREATE INDEX \"expense_change_requests_expense_id_status_index\" on \"expense_change_requests\" (\"expense_id\", \"status\");

CREATE INDEX \"expense_change_requests_user_id_index\" on \"expense_change_requests\" (\"user_id\");

CREATE TABLE \"expense_change_request_votes\" (\"id\" integer primary key autoincrement not null, \"expense_change_request_id\" integer not null, \"user_id\" integer not null, \"vote\" varchar check (\"vote\" in ('approved', 'rejected')) not null, \"notes\" text, \"created_at\" datetime, \"updated_at\" datetime, foreign key(\"expense_change_request_id\") references \"expense_change_requests\"(\"id\") on delete cascade, foreign key(\"user_id\") references \"users\"(\"id\") on delete cascade);

CREATE UNIQUE INDEX \"unique_vote_per_user\" on \"expense_change_request_votes\" (\"expense_change_request_id\", \"user_id\");

CREATE INDEX \"expense_change_request_votes_user_id_vote_index\" on \"expense_change_request_votes\" (\"user_id\", \"vote\");

CREATE TABLE \"paid_debts\" (\"id\" integer primary key autoincrement not null, \"created_at\" datetime, \"updated_at\" datetime, \"debt_id\" integer not null, \"changed_debt_date\" date not null, \"paid_by_user_id\" integer not null, \"payment_status\" varchar check (\"payment_status\" in ('partial', 'paid')) not null, \"paid_sum\" numeric not null, foreign key(\"debt_id\") references \"debts\"(\"id\") on delete cascade, foreign key(\"paid_by_user_id\") references \"users\"(\"id\") on delete cascade);

CREATE INDEX \"paid_debts_debt_id_index\" on \"paid_debts\" (\"debt_id\");

CREATE INDEX \"paid_debts_paid_by_user_id_index\" on \"paid_debts\" (\"paid_by_user_id\");"""

# Lazily-initialized transformers pipeline; created on first use by
# _get_sql_generator() and reused for every subsequent call.
_SQL_GENERATOR: Any | None = None
87
+
88
+
89
@dataclass(frozen=True)
class SqlGenerationRequest:
    """Immutable parameters for one SQL-generation call."""

    # Natural-language request to translate into SQL.
    question: str
    # Row cap appended to non-aggregate queries.
    limit: int = 200
93
+
94
+
95
def _get_sql_generator() -> Any:
    """Return the text2text pipeline, creating and caching it on first use.

    The transformers import is deferred so the module loads fast when the
    endpoint is never hit. Model id comes from the SQL_MODEL env var.
    """
    global _SQL_GENERATOR

    if _SQL_GENERATOR is not None:
        return _SQL_GENERATOR

    from transformers import pipeline

    model_name = os.getenv("SQL_MODEL", "google/flan-t5-base")
    _SQL_GENERATOR = pipeline(
        task="text2text-generation",
        model=model_name,
        tokenizer=model_name,
        device=-1,  # CPU inference
        torch_dtype=torch.float32,
    )
    return _SQL_GENERATOR
111
+
112
+
113
def _build_prompt(request_data: SqlGenerationRequest) -> str:
    """Compose the instruction prompt (rules + schema + question) for the model."""
    rules = (
        "You translate user requests into SQLite SELECT queries. "
        "Return only SQL without explanations. "
        "Use only tables and columns from the schema. "
        "Never generate INSERT, UPDATE, DELETE, DROP, ALTER, PRAGMA, ATTACH or CREATE. "
        "Prefer explicit JOIN conditions using foreign keys from the schema. "
        f"Add LIMIT {request_data.limit} when the query is not an aggregate result."
    )
    return (
        f"{rules}\n\n"
        f"Schema:\n{DEFAULT_DB_SCHEMA}\n\n"
        f"User request:\n{request_data.question}\n\n"
        "SQL:"
    )
125
+
126
+
127
+ def _normalize_sql(raw_sql: str, limit: int) -> str:
128
+ sql = (raw_sql or "").strip()
129
+ if not sql:
130
+ raise ValueError("SQL model returned an empty result.")
131
+
132
+ if "```" in sql:
133
+ parts = [part.strip() for part in sql.split("```") if part.strip()]
134
+ sql = parts[-1]
135
+
136
+ upper_sql = sql.upper()
137
+ sql_start = upper_sql.find("SELECT")
138
+ if sql_start == -1:
139
+ raise ValueError("Generated SQL is not a SELECT query.")
140
+
141
+ sql = sql[sql_start:]
142
+ if ";" in sql:
143
+ sql = sql.split(";", 1)[0].strip()
144
+
145
+ upper_sql = sql.upper()
146
+ forbidden = ("INSERT ", "UPDATE ", "DELETE ", "DROP ", "ALTER ", "PRAGMA ", "ATTACH ", "CREATE ", "REPLACE ")
147
+ if any(keyword in upper_sql for keyword in forbidden):
148
+ raise ValueError("Generated SQL contains forbidden statements.")
149
+
150
+ if not upper_sql.startswith("SELECT "):
151
+ raise ValueError("Only SELECT queries are allowed.")
152
+
153
+ aggregate_markers = ("COUNT(", "SUM(", "AVG(", "MIN(", "MAX(")
154
+ has_limit = " LIMIT " in upper_sql
155
+ if not has_limit and not any(marker in upper_sql for marker in aggregate_markers):
156
+ sql = f"{sql} LIMIT {limit}"
157
+
158
+ return sql
159
+
160
+
161
def generate_sql(question: str, limit: int = 200) -> str:
    """Translate a natural-language question into a sanitized SQLite SELECT.

    Args:
        question: User request in natural language; must be non-empty.
        limit: Row cap applied to non-aggregate queries (default 200).

    Returns:
        A single SELECT statement string.

    Raises:
        ValueError: if the question is empty or the model output is unusable.
    """
    normalized_question = (question or "").strip()
    if not normalized_question:
        raise ValueError("Field 'query' is required.")

    request_data = SqlGenerationRequest(
        question=normalized_question,
        limit=limit,
    )

    generator = _get_sql_generator()
    outputs = generator(
        _build_prompt(request_data),
        max_new_tokens=256,
        do_sample=False,  # deterministic decoding
        truncation=True,
    )

    generated_text = outputs[0].get("generated_text", "") if outputs else ""
    return _normalize_sql(generated_text, limit=request_data.limit)