File size: 2,897 Bytes
cc2ed2f
8871df9
cc2ed2f
 
 
 
 
8871df9
cc2ed2f
8871df9
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
cc2ed2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Тесты на PromptBuilder.

Покрывают как базовое формирование chat-template, так и опциональную
интеграцию BusinessVocabulary в системное сообщение (раздел 3.6 ВКР).
"""

from src.business.vocabulary import BusinessVocabulary
from src.data.prompt import (
    BASE_SYSTEM_PROMPT,
    SYSTEM_PROMPT,
    build_chat_messages,
    build_system_message,
    build_training_example,
    build_user_message,
)


def test_user_message_contains_parts():
    """The user message contains the section labels, the schema and the question."""
    schema_ddl = "CREATE TABLE t (id INT);"
    question = "Покажи всё"
    rendered = build_user_message(schema_ddl, question)

    # Checked in a fixed order: the three labels first, then the payload.
    expected_fragments = (
        "Schema:",
        "Question:",
        "SQL:",
        "CREATE TABLE",
        "Покажи всё",
    )
    for fragment in expected_fragments:
        assert fragment in rendered


def test_chat_messages_have_system_and_user():
    """Without a vocabulary the chat is the base system prompt plus a user turn."""
    conversation = build_chat_messages("schema", "question")
    assert len(conversation) == 2

    system_turn = conversation[0]
    assert system_turn["role"] == "system"
    assert system_turn["content"] == BASE_SYSTEM_PROMPT

    assert conversation[1]["role"] == "user"


def test_training_example_has_assistant():
    """A training example has three turns; the third is the assistant's SQL."""
    conversation = build_training_example("schema", "question", "SELECT 1")
    assert len(conversation) == 3

    answer_turn = conversation[2]
    assert answer_turn["role"] == "assistant"
    # The content must equal the SQL string that was passed in.
    assert answer_turn["content"] == "SELECT 1"


def test_legacy_system_prompt_alias():
    """``SYSTEM_PROMPT`` must compare equal to ``BASE_SYSTEM_PROMPT``."""
    legacy_prompt, base_prompt = SYSTEM_PROMPT, BASE_SYSTEM_PROMPT
    assert legacy_prompt == base_prompt


def test_system_message_without_vocabulary():
    """Passing ``None`` as the vocabulary yields the base system prompt."""
    rendered = build_system_message(None)
    assert rendered == BASE_SYSTEM_PROMPT


def test_system_message_with_empty_vocabulary():
    """A vocabulary from ``BusinessVocabulary.empty()`` yields the base system prompt."""
    rendered = build_system_message(BusinessVocabulary.empty())
    assert rendered == BASE_SYSTEM_PROMPT


def test_system_message_with_terms():
    """With a company and a term, the system message extends the base prompt."""
    glossary = {"выручка": "SUM(orders.amount) WHERE orders.status = 'paid'"}
    vocab = BusinessVocabulary(company="ООО Ромашка", terms=glossary)

    rendered = build_system_message(vocab)

    # The base prompt must remain the prefix of the system message.
    assert rendered.startswith(BASE_SYSTEM_PROMPT)
    # The company name, the term and the start of its SQL definition must appear.
    for fragment in ("ООО Ромашка", "выручка", "SUM(orders.amount)"):
        assert fragment in rendered


def test_chat_messages_with_vocabulary_keeps_user_clean():
    """The term definition appears in the system turn and not in the user turn."""
    glossary = {"выручка": "SUM(amount) WHERE status='paid'"}
    vocab = BusinessVocabulary(terms=glossary)
    conversation = build_chat_messages(
        "schema", "Какая выручка?", vocabulary=vocab
    )

    system_turn = conversation[0]
    assert system_turn["role"] == "system"
    assert "SUM(amount)" in system_turn["content"]

    user_turn = conversation[1]
    assert user_turn["role"] == "user"
    # The user turn contains the question but none of the term's SQL.
    assert "SUM(amount)" not in user_turn["content"]
    assert "Какая выручка?" in user_turn["content"]


def test_training_example_with_vocabulary():
    """A training example built with a vocabulary has the term in its system turn."""
    glossary = {"топ": "ORDER BY x DESC LIMIT 10"}
    vocab = BusinessVocabulary(terms=glossary)

    conversation = build_training_example(
        "schema",
        "Топ клиентов",
        "SELECT 1",
        vocabulary=vocab,
    )
    assert len(conversation) == 3

    system_turn = conversation[0]
    assert system_turn["role"] == "system"
    assert "ORDER BY x DESC" in system_turn["content"]

    assert conversation[2]["role"] == "assistant"