File size: 3,345 Bytes
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""FastAPI приложение.

Запуск:
    uvicorn src.api.main:app --reload
    # Swagger UI: http://127.0.0.1:8000/docs
"""

from __future__ import annotations

import sqlite3
from pathlib import Path

from fastapi import Depends, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool

from src.api.dependencies import get_engine, get_schema_retriever, lifespan
from src.api.schemas import (
    DatabaseInfo,
    ExecutionResult,
    GenerateRequest,
    GenerateResponse,
    HealthResponse,
)
from src.config import settings
from src.data.schema import SchemaRetriever
from src.models.inference import InferenceEngine
from src.models.postprocess import is_valid_sql

# The FastAPI application object; the route handlers below register on it via
# decorators, and startup/shutdown handling is delegated to ``lifespan``.
# (The description reads "Converting Russian-language questions into SQL queries".)
app = FastAPI(
    lifespan=lifespan,
    title="ru2sql",
    version="0.1.0",
    description="Преобразование вопросов на русском в SQL-запросы",
)


@app.get("/health", response_model=HealthResponse)
def health(engine: InferenceEngine = Depends(get_engine)):
    # Liveness probe: always reports status "ok", together with whether the
    # inference engine has loaded its model and which base model it uses.
    # (Plain comments rather than a docstring: FastAPI would publish a
    # docstring as this endpoint's OpenAPI description.)
    #
    # NOTE(review): reads the engine's private ``_loaded`` flag; a public
    # accessor on InferenceEngine would be cleaner.
    is_loaded = engine._loaded
    model_name = engine.base_model_name
    return HealthResponse(
        base_model=model_name,
        model_loaded=is_loaded,
        status="ok",
    )


@app.get("/databases", response_model=list[DatabaseInfo])
def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
    # Return every database known to the retriever together with its table
    # names. A database whose table lookup raises FileNotFoundError is skipped
    # so that one missing file does not fail the whole listing.
    # (Comments, not a docstring, to keep the OpenAPI description unchanged.)
    catalog: list[DatabaseInfo] = []
    for db_name in retriever.list_databases():
        try:
            # The comprehension stays inside the try: if get_tables yields
            # lazily, FileNotFoundError may surface only while iterating.
            table_names = [
                tbl.name for tbl in retriever.get_tables(db_name, n_sample_rows=0)
            ]
        except FileNotFoundError:
            continue
        else:
            catalog.append(DatabaseInfo(db_id=db_name, tables=table_names))
    return catalog


@app.post("/generate-sql", response_model=GenerateResponse)
async def generate_sql(
    req: GenerateRequest,
    engine: InferenceEngine = Depends(get_engine),
    retriever: SchemaRetriever = Depends(get_schema_retriever),
):
    # Generate SQL for ``req.question`` against the schema of ``req.db_id``.
    # When ``req.execute`` is set and the SQL passes ``is_valid_sql``, the
    # query is also run against that database and its result (or the error
    # text) is attached to the response.
    #
    # Raises HTTPException(404) if the schema for ``req.db_id`` is not found.
    # SQL execution failures do not raise; they are reported in
    # ``response.error``.
    # (Comments, not a docstring, to keep the OpenAPI description unchanged.)
    try:
        # render_schema hits the filesystem (it can raise FileNotFoundError),
        # so run it off the event loop like the other blocking calls below.
        schema_text = await run_in_threadpool(retriever.render_schema, req.db_id)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Inference is synchronous and heavy, so it is offloaded to the threadpool.
    result = await run_in_threadpool(engine.generate, schema_text, req.question)

    valid = is_valid_sql(result.sql)
    response = GenerateResponse(
        sql=result.sql,
        raw_output=result.raw_output,
        is_valid_sql=valid,
    )

    if req.execute and valid:
        try:
            response.execution = await run_in_threadpool(
                _execute_sql, req.db_id, result.sql, retriever
            )
        except (sqlite3.Error, sqlite3.Warning) as e:
            # On Python < 3.12, Cursor.execute raises sqlite3.Warning, which is
            # NOT a subclass of sqlite3.Error, when the SQL contains more than
            # one statement. Catch it too so it is reported, not a 500.
            response.error = f"SQL execution error: {e}"

    return response


def _execute_sql(
    db_id: str,
    sql: str,
    retriever: SchemaRetriever,
    *,
    max_rows: int = 100,
) -> ExecutionResult:
    """Execute ``sql`` on a read-only connection to the database ``db_id``.

    Args:
        db_id: Database identifier, resolved to a file via ``retriever.db_path``.
        sql: A single SQL statement to run.
        retriever: Provides the on-disk location of the database.
        max_rows: Upper bound on the number of rows fetched (must be >= 1);
            any further rows of the result set are not read.

    Returns:
        ExecutionResult with the column names (empty if the statement yields
        no result set), at most ``max_rows`` rows, and ``row_count`` equal to
        the number of rows returned, not the full result size.

    Raises:
        ValueError: if ``max_rows`` is less than 1.
        sqlite3.Error: if the file cannot be opened, the SQL fails, or the
            statement tries to write through the read-only connection.
        sqlite3.Warning: on Python < 3.12, if ``sql`` holds several statements.
    """
    # sqlite3's fetchmany() returns ALL rows for a size of 0 or less, so a
    # non-positive limit would silently disable the cap.
    if max_rows < 1:
        raise ValueError(f"max_rows must be >= 1, got {max_rows}")

    db_path = retriever.db_path(db_id)
    # Build a well-formed file: URI. Interpolating the raw path is unsafe: a '#'
    # turns the rest (including '?mode=ro') into a fragment, so SQLite opens a
    # truncated path read-write and may create it, and '%' sequences get
    # decoded. as_uri() percent-encodes the path so mode=ro is always honored.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    try:
        # Decode TEXT values leniently: invalid UTF-8 becomes U+FFFD instead
        # of raising during fetch.
        conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchmany(max_rows)
        cols = [d[0] for d in cur.description] if cur.description else []
        return ExecutionResult(
            columns=cols,
            rows=[list(r) for r in rows],
            row_count=len(rows),
        )
    finally:
        conn.close()


if __name__ == "__main__":
    # Development entry point: serve the app with auto-reload. The app is
    # given as an import string rather than the object because uvicorn needs
    # that form to re-import the module when reload is enabled.
    import uvicorn

    uvicorn.run(
        "src.api.main:app",
        reload=True,
        host=settings.api_host,
        port=settings.api_port,
    )