ryanwang-trt committed on
Commit
7493570
·
1 Parent(s): 87e5f2c

Deploy SQLator backend

Browse files
Files changed (5) hide show
  1. .dockerignore +29 -0
  2. Dockerfile +27 -0
  3. app.py +356 -0
  4. config.py +22 -0
  5. requirements.txt +9 -0
.dockerignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .gitignore
3
+ .gitattributes
4
+ .github
5
+ .claude
6
+ .vscode
7
+ .idea
8
+
9
+ __pycache__
10
+ *.pyc
11
+ *.pyo
12
+ *.pyd
13
+ .pytest_cache
14
+ .mypy_cache
15
+ .ruff_cache
16
+
17
+ .venv
18
+ venv
19
+ env
20
+ ENV
21
+
22
+ chrome-extension
23
+ data
24
+ models
25
+ *.log
26
+ demo.gif
27
+ README.md
28
+ Dockerfile
29
+ .dockerignore
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Image that serves the Flask app (app:app) with gunicorn on port 7860.
FROM python:3.11-slim

# libgomp1 is the only extra system package installed; the apt lists are
# removed in the same layer to keep the image small.
# NOTE(review): presumably required by the CPU torch wheels - confirm.
RUN apt-get update \
    && apt-get install -y --no-install-recommends libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Run everything below as an unprivileged user (UID 1000).
RUN useradd -m -u 1000 user
USER user

# pip --user installs land in ~/.local/bin, so that directory is put on PATH.
ENV HOME=/home/user
ENV PATH=/home/user/.local/bin:$PATH
ENV HF_HOME=/home/user/.cache/huggingface
ENV PYTHONUNBUFFERED=1

WORKDIR /home/user/app

# Dependencies are installed before the source is copied so that this layer
# is only rebuilt when requirements.txt changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir --user \
        --extra-index-url https://download.pytorch.org/whl/cpu \
        -r requirements.txt \
    && pip install --no-cache-dir --user gunicorn

COPY --chown=user . .

EXPOSE 7860

# One worker with a 300 s request timeout.
CMD ["gunicorn", "-w", "1", "-t", "300", "-b", "0.0.0.0:7860", "app:app"]
app.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import torch
4
+ from flask import Flask, request, render_template_string, jsonify
5
+ from flask_cors import CORS
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from config import MODEL_PATH, HF_MODEL_ID, MAX_INPUT_LENGTH, MAX_OUTPUT_LENGTH, NUM_BEAMS, PROMPT_TEMPLATE, MAX_QUESTION_LENGTH, MAX_SCHEMA_LENGTH
8
+ from schema import truncate_schema
9
+
10
# Process-wide logging: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)

# Flask application with CORS enabled using flask-cors defaults.
app = Flask(__name__)
CORS(app)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily populated by get_model(); both stay None until the first load.
tokenizer = None
model = None
20
+
21
def get_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The source is ``MODEL_PATH`` when that path exists on disk; otherwise the
    identifier ``HF_MODEL_ID`` is handed to ``from_pretrained``. The loaded
    objects are stored in the module-level ``tokenizer`` and ``model``
    globals, so later calls return them without loading again.
    """
    global tokenizer, model

    # Already loaded: hand back the cached pair.
    if model is not None:
        return tokenizer, model

    has_local_copy = os.path.exists(MODEL_PATH)
    if not has_local_copy:
        log.info(f"Local model not found at '{MODEL_PATH}', downloading from HuggingFace: {HF_MODEL_ID}")
    source = MODEL_PATH if has_local_copy else HF_MODEL_ID

    tokenizer = AutoTokenizer.from_pretrained(source)
    model = AutoModelForSeq2SeqLM.from_pretrained(source)
    model = model.to(device)
    model.eval()  # inference mode
    log.info(f"Model loaded from {source} on {device}")
    return tokenizer, model
35
+
36
def predict(question, db_id="unknown", schema="unknown"):
    """Generate SQL text for a natural-language *question*.

    Args:
        question: The question to translate.
        db_id: Database name inserted into the prompt.
        schema: Schema description; passed through ``truncate_schema`` with
            ``MAX_SCHEMA_LENGTH`` before it is inserted into the prompt.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    prompt = PROMPT_TEMPLATE.format(
        db_id=db_id,
        schema=truncate_schema(schema, MAX_SCHEMA_LENGTH),
        question=question,
    )

    tok, net = get_model()
    encoded = tok(
        prompt,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )

    # Inputs must live on the same device the model was moved to.
    generated = net.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        max_length=MAX_OUTPUT_LENGTH,
        num_beams=NUM_BEAMS,
    )
    best = generated[0]
    return tok.decode(best, skip_special_tokens=True)
48
+
49
+ HTML = """
50
+ <!DOCTYPE html>
51
+ <html>
52
+ <head>
53
+ <title>SQLator — Natural Language to SQL</title>
54
+ <link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=DM+Sans:wght@400;500;700&display=swap" rel="stylesheet">
55
+ <style>
56
+ * { margin: 0; padding: 0; box-sizing: border-box; }
57
+
58
+ body {
59
+ font-family: 'DM Sans', sans-serif;
60
+ min-height: 100vh;
61
+ background: #0a0a0f;
62
+ color: #e0e0e0;
63
+ display: flex;
64
+ align-items: center;
65
+ justify-content: center;
66
+ overflow: hidden;
67
+ }
68
+
69
+ /* animated background grid */
70
+ body::before {
71
+ content: '';
72
+ position: fixed;
73
+ top: 0; left: 0; right: 0; bottom: 0;
74
+ background-image:
75
+ linear-gradient(rgba(56, 189, 248, 0.03) 1px, transparent 1px),
76
+ linear-gradient(90deg, rgba(56, 189, 248, 0.03) 1px, transparent 1px);
77
+ background-size: 60px 60px;
78
+ z-index: 0;
79
+ }
80
+
81
+ /* glow orb */
82
+ body::after {
83
+ content: '';
84
+ position: fixed;
85
+ top: -200px; right: -200px;
86
+ width: 600px; height: 600px;
87
+ background: radial-gradient(circle, rgba(56, 189, 248, 0.08), transparent 70%);
88
+ border-radius: 50%;
89
+ z-index: 0;
90
+ }
91
+
92
+ .container {
93
+ position: relative;
94
+ z-index: 1;
95
+ width: 100%;
96
+ max-width: 680px;
97
+ padding: 20px;
98
+ }
99
+
100
+ .badge {
101
+ display: inline-block;
102
+ padding: 6px 14px;
103
+ background: rgba(56, 189, 248, 0.1);
104
+ border: 1px solid rgba(56, 189, 248, 0.2);
105
+ border-radius: 100px;
106
+ font-size: 12px;
107
+ font-weight: 500;
108
+ color: #38bdf8;
109
+ letter-spacing: 1.5px;
110
+ text-transform: uppercase;
111
+ margin-bottom: 20px;
112
+ }
113
+
114
+ h1 {
115
+ font-family: 'JetBrains Mono', monospace;
116
+ font-size: 42px;
117
+ font-weight: 700;
118
+ color: #ffffff;
119
+ line-height: 1.1;
120
+ margin-bottom: 8px;
121
+ }
122
+
123
+ h1 span {
124
+ background: linear-gradient(135deg, #38bdf8, #818cf8);
125
+ -webkit-background-clip: text;
126
+ -webkit-text-fill-color: transparent;
127
+ }
128
+
129
+ .subtitle {
130
+ color: #6b7280;
131
+ font-size: 15px;
132
+ margin-bottom: 40px;
133
+ }
134
+
135
+ .card {
136
+ background: rgba(255, 255, 255, 0.03);
137
+ border: 1px solid rgba(255, 255, 255, 0.06);
138
+ border-radius: 16px;
139
+ padding: 32px;
140
+ backdrop-filter: blur(20px);
141
+ }
142
+
143
+ label {
144
+ display: block;
145
+ font-size: 13px;
146
+ font-weight: 500;
147
+ color: #9ca3af;
148
+ margin-bottom: 8px;
149
+ letter-spacing: 0.5px;
150
+ }
151
+
152
+ input[type=text] {
153
+ width: 100%;
154
+ padding: 14px 16px;
155
+ background: rgba(0, 0, 0, 0.4);
156
+ border: 1px solid rgba(255, 255, 255, 0.08);
157
+ border-radius: 10px;
158
+ color: #f0f0f0;
159
+ font-family: 'DM Sans', sans-serif;
160
+ font-size: 15px;
161
+ outline: none;
162
+ transition: border-color 0.2s;
163
+ margin-bottom: 20px;
164
+ }
165
+
166
+ input[type=text]:focus, textarea:focus {
167
+ border-color: rgba(56, 189, 248, 0.4);
168
+ }
169
+
170
+ input[type=text]::placeholder, textarea::placeholder {
171
+ color: #4b5563;
172
+ }
173
+
174
+ textarea {
175
+ width: 100%;
176
+ padding: 14px 16px;
177
+ background: rgba(0, 0, 0, 0.4);
178
+ border: 1px solid rgba(255, 255, 255, 0.08);
179
+ border-radius: 10px;
180
+ color: #f0f0f0;
181
+ font-family: 'JetBrains Mono', monospace;
182
+ font-size: 13px;
183
+ outline: none;
184
+ transition: border-color 0.2s;
185
+ margin-bottom: 20px;
186
+ resize: vertical;
187
+ }
188
+
189
+ button {
190
+ width: 100%;
191
+ padding: 14px;
192
+ background: linear-gradient(135deg, #38bdf8, #818cf8);
193
+ color: #fff;
194
+ font-family: 'DM Sans', sans-serif;
195
+ font-size: 15px;
196
+ font-weight: 600;
197
+ border: none;
198
+ border-radius: 10px;
199
+ cursor: pointer;
200
+ transition: opacity 0.2s, transform 0.1s;
201
+ letter-spacing: 0.3px;
202
+ }
203
+
204
+ button:hover { opacity: 0.9; }
205
+ button:active { transform: scale(0.98); }
206
+
207
+ .result {
208
+ margin-top: 28px;
209
+ padding-top: 28px;
210
+ border-top: 1px solid rgba(255, 255, 255, 0.06);
211
+ }
212
+
213
+ .result-label {
214
+ font-size: 12px;
215
+ font-weight: 500;
216
+ color: #6b7280;
217
+ letter-spacing: 1px;
218
+ text-transform: uppercase;
219
+ margin-bottom: 6px;
220
+ }
221
+
222
+ .result-question {
223
+ color: #d1d5db;
224
+ font-size: 15px;
225
+ margin-bottom: 16px;
226
+ }
227
+
228
+ .sql-output {
229
+ background: rgba(0, 0, 0, 0.5);
230
+ border: 1px solid rgba(56, 189, 248, 0.15);
231
+ border-radius: 10px;
232
+ padding: 16px 20px;
233
+ font-family: 'JetBrains Mono', monospace;
234
+ font-size: 14px;
235
+ color: #38bdf8;
236
+ line-height: 1.6;
237
+ overflow-x: auto;
238
+ }
239
+
240
+ .footer {
241
+ text-align: center;
242
+ margin-top: 32px;
243
+ font-size: 12px;
244
+ color: #374151;
245
+ }
246
+
247
+ .footer a {
248
+ color: #4b5563;
249
+ text-decoration: none;
250
+ }
251
+
252
+ /* fade in animation */
253
+ .container { animation: fadeUp 0.6s ease-out; }
254
+
255
+ @keyframes fadeUp {
256
+ from { opacity: 0; transform: translateY(20px); }
257
+ to { opacity: 1; transform: translateY(0); }
258
+ }
259
+ </style>
260
+ </head>
261
+ <body>
262
+ <div class="container">
263
+ <div class="badge">Fine-tuned CodeT5+ Model</div>
264
+ <h1>SQL<span>ator</span></h1>
265
+ <p class="subtitle">Ask a question in plain English. Get a SQL query back.</p>
266
+
267
+ <div class="card">
268
+ <form method="POST">
269
+ <label>YOUR QUESTION</label>
270
+ <input type="text" name="question" placeholder="e.g. how many employees are in each department" value="{{ question or '' }}" autofocus>
271
+
272
+ <label>DATABASE (OPTIONAL)</label>
273
+ <input type="text" name="db_id" placeholder="e.g. concert_singer" value="{{ db_id or '' }}">
274
+
275
+ <label>SCHEMA (OPTIONAL)</label>
276
+ <textarea name="schema" rows="3" placeholder="e.g. singer(singer_id, name, country, age), concert(concert_id, concert_name, theme)">{{ schema or '' }}</textarea>
277
+
278
+ <button type="submit">Generate SQL →</button>
279
+ </form>
280
+
281
+ {% if error %}
282
+ <div class="result">
283
+ <div style="color: #f87171; font-size: 14px;">{{ error }}</div>
284
+ </div>
285
+ {% endif %}
286
+
287
+ {% if sql %}
288
+ <div class="result">
289
+ <div class="result-label">Input</div>
290
+ <div class="result-question">{{ question }}</div>
291
+
292
+ <div class="result-label">Generated SQL</div>
293
+ <div class="sql-output">{{ sql }}</div>
294
+ </div>
295
+ {% endif %}
296
+ </div>
297
+
298
+ <div class="footer">
299
+ Built with CodeT5+ 220M + PyTorch — <a href="https://github.com">View on GitHub</a>
300
+ </div>
301
+ </div>
302
+ </body>
303
+ </html>
304
+ """
305
+
306
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers with the JSON body ``{"status": "ok"}``."""
    return jsonify(status="ok")
309
+
310
+
311
@app.route("/predict", methods=["POST"])
def predict_api():
    """JSON endpoint that turns a question into SQL.

    Request body (JSON object):
        question: required string, at most ``MAX_QUESTION_LENGTH`` characters.
        db_id: optional string; defaults to ``"unknown"``.
        schema: optional string; defaults to ``"unknown"``.

    Returns:
        ``{"sql": ...}`` with status 200 on success, ``{"error": ...}`` with
        status 400 for invalid input, or ``{"error": ...}`` with status 500
        when inference raises.
    """
    payload = request.get_json(silent=True)
    # A valid JSON body that is not an object (list, number, ...) has no
    # fields to read; treat it like an empty request instead of crashing.
    if not isinstance(payload, dict):
        payload = {}

    def _text(key):
        # Non-string values are treated as missing so .strip() cannot raise.
        value = payload.get(key)
        return value.strip() if isinstance(value, str) else ""

    question = _text("question")
    db_id = _text("db_id") or "unknown"
    # Accept the same optional schema the HTML form accepts.
    schema = _text("schema") or "unknown"

    if not question:
        return jsonify({"error": "Please enter a question."}), 400
    if len(question) > MAX_QUESTION_LENGTH:
        return jsonify({"error": f"Question is too long (max {MAX_QUESTION_LENGTH} characters)."}), 400

    try:
        # %r keeps control characters in user input from splitting log lines.
        log.info("API predict: question=%r db_id=%r", question, db_id)
        sql = predict(question, db_id, schema=schema)
        return jsonify({"sql": sql})
    except Exception:
        # Full details go to the server log only; the client gets a generic
        # message so internal paths and stack details are not disclosed.
        log.exception("Prediction failed")
        return jsonify({"error": "Inference failed. Please try again later."}), 500
329
+
330
+
331
@app.route("/", methods=["GET", "POST"])
def home():
    """Render the HTML form and, on POST, the generated SQL or an error.

    GET renders the empty form. POST reads ``question``, ``db_id`` and
    ``schema`` from the submitted form, validates the question and calls
    ``predict``. The values the user typed are echoed back into the form
    unchanged; the ``"unknown"`` placeholder is only substituted when
    calling the model.

    Returns:
        The rendered page with status 200, or with status 500 when
        inference raises.
    """
    question = None
    db_id = None
    schema = None
    sql = None
    error = None
    status = 200

    if request.method == "POST":
        question = request.form.get("question", "").strip()
        db_id = request.form.get("db_id", "").strip()
        schema = request.form.get("schema", "").strip()

        if not question:
            error = "Please enter a question."
        elif len(question) > MAX_QUESTION_LENGTH:
            error = f"Question is too long (max {MAX_QUESTION_LENGTH} characters)."
        else:
            # Placeholders are applied here only, so they never leak back
            # into the user's input fields on re-render.
            effective_db_id = db_id or "unknown"
            effective_schema = schema or "unknown"
            log.info("Predicting for question=%r db_id=%r", question, effective_db_id)
            try:
                sql = predict(question, effective_db_id, schema=effective_schema)
            except Exception:
                # Show the error inside the page instead of a bare 500 response.
                log.exception("Prediction failed")
                error = "Inference failed. Please try again later."
                status = 500

    page = render_template_string(
        HTML, question=question, db_id=db_id, schema=schema, sql=sql, error=error
    )
    return page, status
353
+
354
if __name__ == "__main__":
    # Debug mode is opt-in: only FLASK_DEBUG=true (any letter case) enables it.
    app.run(debug=os.getenv("FLASK_DEBUG", "false").lower() == "true")
config.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# --- Locations; each can be overridden through the environment ---
MODEL_PATH = os.getenv("MODEL_PATH", "models/t5-sql")
BASE_MODEL = os.getenv("BASE_MODEL", "Salesforce/codet5p-220m")
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "ryanwang-trt/t5-sql")
SPIDER_DB_DIR = os.getenv("SPIDER_DB_DIR", "data/database")

# --- Sequence limits and decoding ---
MAX_INPUT_LENGTH = 512
MAX_OUTPUT_LENGTH = 128
NUM_BEAMS = 5

# --- Input-size caps ---
MAX_SCHEMA_LENGTH = 400
MAX_QUESTION_LENGTH = 500

# --- NOTE(review): presumably training hyperparameters; confirm usage ---
BATCH_SIZE = 2
ACCUMULATION_STEPS = 4
NUM_EPOCHS = 6
LEARNING_RATE = 1e-4
WARMUP_RATIO = 0.1

# Prompt layout; placeholders are db_id, schema and question.
PROMPT_TEMPLATE = "translate English to SQL [database: {db_id} | tables: {schema}]: {question}"
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ datasets
4
+ scikit-learn
5
+ accelerate
6
+ flask
7
+ flask-cors
8
+ huggingface_hub
9
+ python-dotenv