File size: 8,512 Bytes
84c630e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import re
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Tuple, Optional, Union


class AlpacaTemplate:
    """Render prompts in the Alpaca instruction-following layout.

    The rendered text is a system preamble, a blank line, an
    ``### Instruction:`` section holding the user content, and a
    ``### Response:`` header optionally followed by the response text.
    """

    # Preamble used whenever the caller supplies no (or an empty) system text.
    DEFAULT_SYSTEM = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )

    @classmethod
    def template(
        cls,
        user_content: str,
        system_content: Union[str, None] = None,
        response: Union[str, None] = None,
    ) -> str:
        """Build one Alpaca-formatted prompt string.

        Args:
            user_content: Text placed under the ``### Instruction:`` header.
            system_content: Preamble text; any falsy value (``None`` or ``""``)
                selects ``DEFAULT_SYSTEM`` instead.
            response: Text appended right after the ``### Response:`` header
                line; a falsy value appends nothing, leaving the prompt open
                for generation.

        Returns:
            The assembled prompt text.
        """
        preamble = system_content or cls.DEFAULT_SYSTEM
        pieces = [
            f"{preamble}\n\n",
            f"### Instruction:\n{user_content}\n\n### Response:\n",
        ]
        if response:
            pieces.append(response)
        return "".join(pieces)


class SftPrompt(ABC):
    """Abstract recipe for turning one data record into an SFT prompt triple.

    Subclasses define how the user content, the optional system content and
    the supervised target are rendered; ``prompt`` assembles the three.
    """

    @classmethod
    @abstractmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Render the user-facing prompt text for ``data``."""

    @classmethod
    @abstractmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Optional[str]:
        """Render the system prompt for ``data``, or return ``None`` for none."""

    @classmethod
    @abstractmethod
    def prompt_target(cls, sql: str) -> str:
        """Render the supervised target text from a SQL string."""

    @classmethod
    def prompt(
        cls, data: Dict[str, Any]
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """Build ``(user_content, system_content, target)`` for one record.

        Args:
            data: The record passed to the subclass hooks. Its optional
                ``"spark_sql"`` entry is the reference SQL for the target.

        Returns:
            The rendered user content, the system content (possibly ``None``)
            and the rendered target. The target is ``None`` when ``data`` has
            no ``"spark_sql"`` key or maps it to ``None`` (e.g. records used
            for inference only).
        """
        user_content = cls.prompt_user_content(data)
        system_content = cls.prompt_system_content(data)
        # An explicit None is treated like a missing key instead of being
        # handed to prompt_target, which expects a SQL string.
        sql = data.get("spark_sql")
        target = None if sql is None else cls.prompt_target(sql)
        return user_content, system_content, target


class SQLGeneratePrompt(SftPrompt):
    """SFT prompt recipe for generating a Spark SQL query from a user question.

    Besides rendering prompts (``prompt_user_content``) and fenced targets
    (``prompt_target``), the class can parse a rendered user prompt back into
    its sections (``extract_*``) and pull the SQL out of a fenced model reply
    (``split_res``).
    """

    # Section markers. They must match the literals written by
    # ``prompt_user_content`` and the ``compose_*`` helpers, because the
    # ``extract_*`` readers locate sections by these exact strings.
    _SCHEMA_BEGIN = "[BEGIN OF TABLE SCHEMAS]"
    _SCHEMA_END = "[END OF TABLE SCHEMAS]"
    _QUERY_BEGIN = "[BEGIN OF QUERY]\nUser Query: "
    _QUERY_END = "[END OF QUERY]"
    _HINT_BEGIN = "[BEGIN OF GENERATION HINT]"
    _HINT_END = "[END OF GENERATION HINT]"
    _RELATED_BEGIN = "[BEGIN OF RELATED QUERIES]"
    _RELATED_END = "[END OF RELATED QUERIES]"
    _SQL_FENCE_OPEN = "```sql"
    _SQL_FENCE_CLOSE = "```"
    # Matches one entry as rendered by ``compose_related_question_sqls_content``;
    # group 1 is the question, group 2 the SQL. Matching whole entries (rather
    # than splitting on blank lines) keeps SQL that contains blank lines intact.
    _RELATED_ITEM_RE = re.compile(
        r"# Related Query \d+\n"
        r"## User Query\n"
        r"`(.*?)`\n"
        r"## Spark SQL Query\n"
        r"```sql\n(.*?)\n```",
        re.DOTALL,
    )

    @classmethod
    def split_res(
        cls, res: Union[str, List[str], None]
    ) -> Union[str, List[Union[str, None]], None]:
        """Extract the SQL from a model reply fenced as a ```sql code block.

        Args:
            res: One reply, a (possibly nested) list of replies, or ``None``.

        Returns:
            For a string: the stripped SQL between the fences, or ``None`` if
            the stripped reply does not both start with "```sql" and end with
            "```". For a list: a new list with each element processed the same
            way (the input list is not modified). ``None`` is returned as is.
        """
        if res is None:
            return None
        if isinstance(res, list):
            return [cls.split_res(item) for item in res]
        text = res.strip()
        if not (
            text.startswith(cls._SQL_FENCE_OPEN)
            and text.endswith(cls._SQL_FENCE_CLOSE)
        ):
            return None
        return text[len(cls._SQL_FENCE_OPEN) : -len(cls._SQL_FENCE_CLOSE)].strip()

    @classmethod
    def compose_extra_task_desc(
        cls,
        hint: Union[str, None],
        related_question_sqls: Union[List[Dict[str, str]], None],
    ) -> str:
        """Return the task-instruction sentence announcing the extra context.

        A hint counts as present only when it has non-whitespace content, which
        is exactly when ``compose_hint_content`` emits a hint section.

        Args:
            hint: Optional generation hint.
            related_question_sqls: Optional list of related question/SQL pairs.

        Returns:
            One newline-terminated sentence, or ``""`` when neither a hint nor
            related queries are provided.
        """
        has_hint = bool(hint and hint.strip())
        has_related = bool(related_question_sqls)
        if has_hint and has_related:
            return "I provide generation hint and other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        if has_hint:
            return (
                "I provide generation hint, you can refer to it to help you generate.\n"
            )
        if has_related:
            return "I provide other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        return ""

    @classmethod
    def compose_hint_content(cls, hint: Union[str, None]) -> str:
        """Render the generation-hint section of the prompt.

        Args:
            hint: Optional hint text; surrounding whitespace is stripped.

        Returns:
            The marked hint section followed by a blank line, or ``""`` when
            the hint is ``None`` or contains only whitespace.
        """
        if hint is None:
            return ""
        hint = hint.strip()
        # Stripping before the emptiness test avoids emitting an empty section
        # for whitespace-only hints.
        if not hint:
            return ""
        return f"{cls._HINT_BEGIN}\n{hint}\n{cls._HINT_END}\n\n"

    @classmethod
    def compose_related_question_sqls_content(
        cls, related_question_sqls: Union[List[Dict[str, str]], None]
    ) -> str:
        """Render the related-queries section of the prompt.

        Args:
            related_question_sqls: Optional list of dicts with ``"question"``
                and ``"spark_sql"`` keys. Questions are stripped and their
                newlines replaced by spaces; SQL strings are stripped.

        Returns:
            The marked section with one numbered entry per pair, entries
            separated by a blank line and the section followed by a blank
            line, or ``""`` when the list is ``None`` or empty.
        """
        if not related_question_sqls:
            return ""
        entries = []
        for number, question_sql in enumerate(related_question_sqls, start=1):
            # Newlines are flattened so the question fits on its backticked line.
            question = question_sql["question"].strip().replace("\n", " ")
            sql = question_sql["spark_sql"].strip()
            entries.append(
                f"# Related Query {number}\n"
                "## User Query\n"
                f"`{question}`\n"
                "## Spark SQL Query\n"
                "```sql\n"
                f"{sql}\n"
                "```"
            )
        return (
            f"{cls._RELATED_BEGIN}\n"
            + "\n\n".join(entries)
            + f"\n{cls._RELATED_END}\n\n"
        )

    @classmethod
    def _extract_section(
        cls, user_content: str, begin: str, end: str
    ) -> Optional[str]:
        """Return the stripped text between ``begin`` and the next ``end``.

        Returns ``None`` when ``begin`` is absent or no ``end`` follows it.
        """
        start = user_content.find(begin)
        if start == -1:
            return None
        start += len(begin)
        stop = user_content.find(end, start)
        if stop == -1:
            return None
        return user_content[start:stop].strip()

    @classmethod
    def extract_table_schema(cls, user_content: str) -> str:
        """Return the table-schema section of a rendered prompt.

        Returns ``""`` when the section markers are not found.
        """
        section = cls._extract_section(
            user_content, cls._SCHEMA_BEGIN, cls._SCHEMA_END
        )
        return "" if section is None else section

    @classmethod
    def extract_user_query(cls, user_content: str) -> str:
        """Return the user question from the query section of a rendered prompt.

        Returns ``""`` when the section markers are not found.
        """
        section = cls._extract_section(
            user_content, cls._QUERY_BEGIN, cls._QUERY_END
        )
        return "" if section is None else section

    @classmethod
    def extract_hint(cls, user_content: str) -> Union[str, None]:
        """Return the generation hint of a rendered prompt, or ``None`` if absent."""
        return cls._extract_section(user_content, cls._HINT_BEGIN, cls._HINT_END)

    @classmethod
    def extract_related_question_sqls(
        cls, user_content: str
    ) -> Union[List[Dict[str, str]], None]:
        """Parse the related-queries section of a rendered prompt.

        Returns:
            A list of ``{"question": ..., "spark_sql": ...}`` dicts in prompt
            order (SQL stripped, matching what the composer wrote), an empty
            list for a section with no parsable entries, or ``None`` when the
            section is absent.
        """
        body = cls._extract_section(
            user_content, cls._RELATED_BEGIN, cls._RELATED_END
        )
        if body is None:
            return None
        return [
            {"question": match.group(1), "spark_sql": match.group(2).strip()}
            for match in cls._RELATED_ITEM_RE.finditer(body)
        ]

    @classmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Render the full Alpaca-wrapped user prompt for one record.

        Args:
            data: Record with required ``"question"`` and ``"schema"`` strings
                and optional ``"hint"`` (str or None) and
                ``"related_question_sqls"`` (list of dicts with ``"question"``
                and ``"spark_sql"`` keys, or None). Missing optional keys are
                treated as None.

        Returns:
            The prompt text passed through ``AlpacaTemplate.template`` with no
            system content, so the template's default preamble is used.
        """
        question = data["question"].strip()
        schema = data["schema"].strip()
        # Optional context: absent keys behave like explicit None values.
        hint = data.get("hint")
        related_question_sqls = data.get("related_question_sqls")

        extra_task_desc = cls.compose_extra_task_desc(hint, related_question_sqls)
        hint_content = cls.compose_hint_content(hint)
        related_question_sqls_content = cls.compose_related_question_sqls_content(
            related_question_sqls
        )

        user_content = (
            "[BEGIN OF TASK INSTRUCTION]\n"
            "You are an expert in composing Spark SQL queries. You are given a user query and a set of table schemas.\n"
            "Based on the user query, you need to generate one Spark SQL query to achieve the purpose.\n"
            f"{extra_task_desc}"
            "[END OF TASK INSTRUCTION]\n"
            "\n"
            "[BEGIN OF TABLE SCHEMAS]\n"
            f"{schema}\n"
            "[END OF TABLE SCHEMAS]\n"
            "\n"
            f"{hint_content}"
            f"{related_question_sqls_content}"
            "[BEGIN OF FORMAT INSTRUCTION]\n"
            "The output MUST strictly adhere to the following format, and NO other text MUST be included.\n"
            "```sql\n"
            "your output Spark SQL query\n"
            "```\n"
            "[END OF FORMAT INSTRUCTION]\n"
            "\n"
            "[BEGIN OF QUERY]\n"
            f"User Query: {question}\n"
            "[END OF QUERY]\n"
        )

        return AlpacaTemplate.template(user_content)

    @classmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Union[str, None]:
        """Return ``None``: this recipe uses no separate system content."""
        return None

    @classmethod
    def prompt_target(cls, sql: Union[str, List[str]]) -> Union[str, List[str]]:
        """Wrap SQL in a ```sql fence, terminated by a single semicolon.

        Args:
            sql: One SQL string or a list of them.

        Returns:
            The fenced target, or a list of fenced targets for list input.
        """

        def fence(one_sql: str) -> str:
            # Strip whitespace, then semicolons, then whitespace again, and
            # append exactly one ';' before fencing.
            body = one_sql.strip().strip(";").strip()
            return f"```sql\n{body};\n```"

        if isinstance(sql, str):
            return fence(sql)
        return [fence(item) for item in sql]