# File size: 8,512 Bytes
# 84c630e
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Tuple, Optional, Union
class AlpacaTemplate:
    """Renders text in the Alpaca "Instruction / Response" layout."""

    # Preamble used whenever the caller supplies no (or an empty) system text.
    DEFAULT_SYSTEM = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )

    @classmethod
    def template(
        cls,
        user_content: str,
        system_content: Union[str, None] = None,
        response: Union[str, None] = None,
    ) -> str:
        """Return the rendered prompt.

        Args:
            user_content: Text placed under the ``### Instruction:`` heading.
            system_content: Preamble placed before the instruction; when it is
                ``None`` or empty, ``DEFAULT_SYSTEM`` is used instead.
            response: Text appended after the ``### Response:`` heading; when
                it is ``None`` or empty nothing is appended.

        Returns:
            The preamble, instruction and response header joined into one
            string, followed by ``response`` when one is given.
        """
        preamble = system_content or cls.DEFAULT_SYSTEM
        rendered = (
            f"{preamble}\n\n### Instruction:\n{user_content}\n\n### Response:\n"
        )
        return rendered + response if response else rendered
class SftPrompt(ABC):
    """Abstract base for classes that turn one data record into prompt parts.

    Subclasses implement the three ``prompt_*`` hooks; ``prompt`` combines
    them into a ``(user_content, system_content, target)`` tuple.
    """

    @classmethod
    @abstractmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Build the user part of the prompt from ``data``."""

    @classmethod
    @abstractmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Optional[str]:
        """Build the system part of the prompt, or return ``None`` for none."""

    @classmethod
    @abstractmethod
    def prompt_target(cls, sql: Union[str, List[str]]) -> Union[str, List[str]]:
        """Format the reference SQL (one string or a list) as target text."""

    @classmethod
    def prompt(
        cls, data: Dict[str, Any]
    ) -> Tuple[str, Optional[str], Optional[Union[str, List[str]]]]:
        """Assemble all prompt parts for one record.

        Args:
            data: The record handed to the ``prompt_*`` hooks. If it has a
                ``"spark_sql"`` key, that value is passed to ``prompt_target``.

        Returns:
            ``(user_content, system_content, target)``. ``target`` is ``None``
            when ``data`` has no ``"spark_sql"`` key.
        """
        user_content = cls.prompt_user_content(data)
        system_content = cls.prompt_system_content(data)
        # Records without a reference query (e.g. inference inputs) get no target.
        target = cls.prompt_target(data["spark_sql"]) if "spark_sql" in data else None
        return user_content, system_content, target
class SQLGeneratePrompt(SftPrompt):
    """Builds and parses the prompt for generating a Spark SQL query.

    The ``compose_*`` helpers render the optional prompt sections, the
    ``extract_*`` helpers recover sections from an already rendered prompt,
    and ``split_res`` / ``prompt_target`` convert between raw SQL and its
    fenced "sql" code-block form.
    """

    @classmethod
    def _extract_section(
        cls, text: str, begin_marker: str, end_marker: str
    ) -> Optional[str]:
        """Return the stripped text between two markers, or ``None``.

        ``None`` is returned when either marker is missing. The end marker is
        searched for only after the begin marker.
        """
        begin = text.find(begin_marker)
        if begin == -1:
            return None
        begin += len(begin_marker)
        end = text.find(end_marker, begin)
        if end == -1:
            return None
        return text[begin:end].strip()

    @classmethod
    def split_res(
        cls, res: Union[str, List[str], None]
    ) -> Union[str, List[Union[str, None]], None]:
        """Strip the "sql" code fence from a response.

        Args:
            res: One response string, a list of responses, or ``None``.

        Returns:
            For a string: the text inside the fence, or ``None`` when the
            stripped string does not start with the opening "sql" fence and
            end with a closing fence. For a list: a new list with every
            element processed the same way (the input list is not modified).
            ``None`` is returned unchanged.
        """
        if res is None:
            return None
        if isinstance(res, list):
            return [cls.split_res(item) for item in res]
        text = res.strip()
        if not (text.startswith("```sql") and text.endswith("```")):
            return None
        # Drop the 6-character opening fence and the 3-character closing fence.
        return text[6:-3].strip()

    @classmethod
    def compose_extra_task_desc(
        cls,
        hint: Union[str, None],
        related_question_sqls: Union[List[Dict[str, str]], None],
    ) -> str:
        """Return the task-instruction sentence announcing the optional sections.

        A hint that is ``None``, empty or whitespace-only counts as absent, so
        this agrees with ``compose_hint_content``, which renders nothing for it.
        Returns ``""`` when neither a hint nor related queries are provided.
        """
        has_hint = bool(hint and hint.strip())
        has_related = bool(related_question_sqls)
        if has_hint and has_related:
            return "I provide generation hint and other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        if has_hint:
            return (
                "I provide generation hint, you can refer to it to help you generate.\n"
            )
        if has_related:
            return "I provide other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        return ""

    @classmethod
    def compose_hint_content(cls, hint: Union[str, None]) -> str:
        """Render the generation-hint section.

        Returns ``""`` when ``hint`` is ``None``, empty or whitespace-only;
        otherwise the stripped hint wrapped in its begin/end markers and
        followed by a blank line.
        """
        if hint is None:
            return ""
        hint = hint.strip()
        if not hint:
            return ""
        return (
            "[BEGIN OF GENERATION HINT]\n" f"{hint}\n" "[END OF GENERATION HINT]\n" "\n"
        )

    @classmethod
    def compose_related_question_sqls_content(
        cls, related_question_sqls: Union[List[Dict[str, str]], None]
    ) -> str:
        """Render the related-queries section.

        Args:
            related_question_sqls: Items with ``"question"`` and
                ``"spark_sql"`` keys, or ``None``.

        Returns:
            ``""`` for ``None`` or an empty list; otherwise the numbered
            entries, separated by blank lines, wrapped in the section
            markers and followed by a blank line.
        """
        if not related_question_sqls:
            return ""
        entries = []
        for number, item in enumerate(related_question_sqls, start=1):
            # The question is rendered on a single line between backticks.
            question = item["question"].strip().replace("\n", " ")
            sql = item["spark_sql"].strip()
            entries.append(
                f"# Related Query {number}\n"
                "## User Query\n"
                f"`{question}`\n"
                "## Spark SQL Query\n"
                "```sql\n"
                f"{sql}\n"
                "```"
            )
        return (
            "[BEGIN OF RELATED QUERIES]\n"
            + "\n\n".join(entries)
            + "\n"
            + "[END OF RELATED QUERIES]\n\n"
        )

    @classmethod
    def extract_table_schema(cls, user_content: str) -> str:
        """Return the table-schema section of a prompt, or ``""`` if absent."""
        section = cls._extract_section(
            user_content, "[BEGIN OF TABLE SCHEMAS]", "[END OF TABLE SCHEMAS]"
        )
        return "" if section is None else section

    @classmethod
    def extract_user_query(cls, user_content: str) -> str:
        """Return the user query of a prompt, or ``""`` if absent."""
        section = cls._extract_section(
            user_content, "[BEGIN OF QUERY]\nUser Query: ", "[END OF QUERY]"
        )
        return "" if section is None else section

    @classmethod
    def extract_hint(cls, user_content: str) -> Union[str, None]:
        """Return the generation hint of a prompt, or ``None`` if absent."""
        return cls._extract_section(
            user_content, "[BEGIN OF GENERATION HINT]", "[END OF GENERATION HINT]"
        )

    @classmethod
    def extract_related_question_sqls(
        cls, user_content: str
    ) -> Union[List[Dict[str, str]], None]:
        """Parse the related-queries section of a prompt.

        Returns:
            ``None`` when the section is absent; otherwise a list of
            ``{"question": ..., "spark_sql": ...}`` dicts in prompt order.
            Entries that lack the expected backtick / fence structure are
            skipped. The SQL is returned stripped, without the newline that
            precedes the closing fence.
        """
        section = cls._extract_section(
            user_content, "[BEGIN OF RELATED QUERIES]", "[END OF RELATED QUERIES]"
        )
        if section is None:
            return None
        res = []
        # Split on the blank line plus the next entry header rather than on
        # every blank line, so SQL that itself contains blank lines stays
        # intact. Entries after the first lose their header prefix, which the
        # parsing below does not need.
        for entry in section.split("\n\n# Related Query "):
            question_open = entry.find("`")
            if question_open == -1:
                continue
            question_close = entry.find("`\n##", question_open + 1)
            if question_close == -1:
                continue
            fence_open = entry.find("```sql\n", question_close)
            if fence_open == -1:
                continue
            sql = entry[fence_open + len("```sql\n"):]
            if sql.endswith("```"):
                sql = sql[:-3]
            res.append(
                {
                    "question": entry[question_open + 1:question_close],
                    "spark_sql": sql.strip(),
                }
            )
        return res

    @classmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Render the full user prompt for one record.

        Args:
            data: Must contain ``"question"`` and ``"schema"`` strings. The
                ``"hint"`` and ``"related_question_sqls"`` keys are optional;
                a missing key is treated like ``None``.

        Returns:
            The task instruction, table schemas, optional hint and related
            queries, format instruction and user query, wrapped by
            ``AlpacaTemplate.template`` with its default preamble.
        """
        question = data["question"].strip()
        schema = data["schema"].strip()
        hint = data.get("hint")
        related_question_sqls = data.get("related_question_sqls")

        extra_task_desc = cls.compose_extra_task_desc(hint, related_question_sqls)
        hint_content = cls.compose_hint_content(hint)
        related_question_sqls_content = cls.compose_related_question_sqls_content(
            related_question_sqls
        )
        user_content = (
            "[BEGIN OF TASK INSTRUCTION]\n"
            "You are an expert in composing Spark SQL queries. You are given a user query and a set of table schemas.\n"
            "Based on the user query, you need to generate one Spark SQL query to achieve the purpose.\n"
            f"{extra_task_desc}"
            "[END OF TASK INSTRUCTION]\n"
            "\n"
            "[BEGIN OF TABLE SCHEMAS]\n"
            f"{schema}\n"
            "[END OF TABLE SCHEMAS]\n"
            "\n"
            f"{hint_content}"
            f"{related_question_sqls_content}"
            "[BEGIN OF FORMAT INSTRUCTION]\n"
            "The output MUST strictly adhere to the following format, and NO other text MUST be included.\n"
            "```sql\n"
            "your output Spark SQL query\n"
            "```\n"
            "[END OF FORMAT INSTRUCTION]\n"
            "\n"
            "[BEGIN OF QUERY]\n"
            f"User Query: {question}\n"
            "[END OF QUERY]\n"
        )
        return AlpacaTemplate.template(user_content)

    @classmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Union[str, None]:
        """Return ``None``: this prompt has no separate system part."""
        return None

    @classmethod
    def prompt_target(cls, sql: Union[str, List[str]]) -> Union[str, List[str]]:
        """Wrap SQL in a "sql" code fence, terminated by a semicolon.

        Args:
            sql: One SQL string or a list of SQL strings.

        Returns:
            The fenced text for a string input, or a list of fenced texts for
            a list input.
        """

        def fence(raw_sql: str) -> str:
            # Drop surrounding whitespace and leading/trailing ';' characters,
            # then append one ';' before fencing.
            statement = raw_sql.strip().strip(";").strip()
            return f"```sql\n{statement};\n```"

        if isinstance(sql, str):
            return fence(sql)
        return [fence(item) for item in sql]
|