File size: 4,871 Bytes
ef5cf7c
c066f5b
02c2107
df3c0b2
 
ef5cf7c
df3c0b2
 
 
 
 
 
 
 
 
269f3f7
df3c0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef5cf7c
c066f5b
df3c0b2
 
 
02c2107
c066f5b
 
02c2107
c066f5b
02c2107
c066f5b
 
df3c0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
c066f5b
 
 
df3c0b2
 
 
c066f5b
df3c0b2
 
 
 
 
 
 
 
 
 
 
 
 
c066f5b
df3c0b2
c066f5b
 
 
 
 
df3c0b2
 
c066f5b
 
df3c0b2
 
 
c066f5b
 
02c2107
df3c0b2
 
 
f2c30d1
02c2107
df3c0b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
from llama_cpp import Llama
import os
from huggingface_hub import snapshot_download
import logging

# Настройка логирования для отладки в HF Spaces
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ID модели на HF
model_repo = "eliza555beth2002/DeepSeek-R1-Distill-Text2SQL-OneEpoch-GGUF-q4"

# Локальная директория для модели (HF Spaces монтирует /app)
model_dir = "./model"
model_filename = "unsloth.Q4_K_M.gguf"  # Точное имя файла из репозитория

# Скачивание модели (если не скачана)
if not os.path.exists(os.path.join(model_dir, model_filename)):
    logger.info("Скачиваем модель... Это может занять 5–10 мин.")
    try:
        snapshot_download(
            repo_id=model_repo,
            local_dir=model_dir,
            local_dir_use_symlinks=False,  # Избежать проблем с симлинками в Spaces
            allow_patterns=[model_filename]  # Скачать только GGUF
        )
        logger.info("Модель скачана успешно.")
    except Exception as e:
        logger.error(f"Ошибка скачивания: {e}")
        raise

# Полный путь к модели
model_path = os.path.join(model_dir, model_filename)

# Загрузка модели
logger.info("Загружаем Llama...")
try:
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,  # Контекст для схемы БД
        n_gpu_layers=-1,  # Использовать GPU, если доступно (в Spaces с T4)
        verbose=False,
        seed=42  # Для воспроизводимости
    )
    logger.info("Модель загружена успешно.")
except Exception as e:
    logger.error(f"Ошибка загрузки модели: {e}")
    raise

def generate_sql(natural_query, db_schema):
    # Шаблон промпта для Text2SQL (улучшенный для точности)
    prompt = f"""You are a SQL expert. Given the database schema below, generate a valid SQL query for the natural language question. 
    Use only the provided schema. Output only the SQL query, no explanations.

    Database Schema:
    {db_schema}

    Question: {natural_query}

    SQL Query:"""
    
    try:
        output = llm(
            prompt,
            max_tokens=150,
            temperature=0.1,  # Низкая для точности SQL
            stop=[";", "\n\n", "Question:"],  # Стоп на конце запроса
            echo=False
        )
        sql_query = output['choices'][0]['text'].strip()
        if not sql_query.endswith(';'):
            sql_query += ';'
        return sql_query
    except Exception as e:
        return f"Ошибка генерации: {e}"

# Gradio интерфейс
with gr.Blocks(title="Text2SQL Demo") as demo:
    gr.Markdown("# Text2SQL: Преобразование текста в SQL\n\nМодель: DeepSeek-R1-Distill (8B, q4 GGUF)")
    gr.Markdown("**Статус:** Модель загружена. Введите вопрос и схему для генерации SQL.")
    
    with gr.Row():
        natural_input = gr.Textbox(
            label="Вопрос на естественном языке", 
            placeholder="Найди всех пользователей старше 30 лет из Москвы",
            lines=2
        )
        schema_input = gr.Textbox(
            label="Схема БД (CREATE TABLE)", 
            placeholder="CREATE TABLE users (id INT, name VARCHAR(100), age INT, city VARCHAR(50));",
            lines=3
        )
    
    output = gr.Textbox(label="Сгенерированный SQL", lines=3)
    submit_btn = gr.Button("Генерировать SQL", variant="primary")
    
    # Связь событий
    submit_btn.click(
        fn=generate_sql,
        inputs=[natural_input, schema_input],
        outputs=output
    )
    
    # Примеры
    gr.Examples(
        examples=[
            ["Сколько продуктов дороже 100?", "CREATE TABLE products (id INT, name VARCHAR(100), price DECIMAL(10,2));"],
            ["Пользователи из Москвы старше 25", "CREATE TABLE users (id INT, name VARCHAR(100), age INT, city VARCHAR(50));"],
            ["Топ 5 самых дорогих заказов", "CREATE TABLE orders (id INT, customer_id INT, total DECIMAL(10,2));"]
        ],
        inputs=[natural_input, schema_input]
    )
    
    # Футер с логами
    gr.Markdown("**Логи:** Проверьте консоль для деталей.")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)  # Для HF Spaces