rain-SQLCoder / utils /prompt.py
suan-chang's picture
update README.md
84c630e
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Tuple, Optional, Union
class AlpacaTemplate:
    """Render prompts in the Alpaca instruction/response layout."""

    # System preamble used whenever the caller supplies no (or an empty) one.
    DEFAULT_SYSTEM = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )

    @classmethod
    def template(
        cls,
        user_content: str,
        system_content: Union[str, None] = None,
        response: Union[str, None] = None,
    ) -> str:
        """Assemble the full Alpaca prompt string.

        Args:
            user_content: Text placed under the ``### Instruction:`` header.
            system_content: Optional preamble; any falsy value (``None`` or
                ``""``) selects ``DEFAULT_SYSTEM`` instead.
            response: Optional text appended after the ``### Response:``
                header; falsy values append nothing.

        Returns:
            The preamble, instruction and response sections concatenated
            into a single string.
        """
        preamble = system_content or cls.DEFAULT_SYSTEM
        pieces = [
            f"{preamble}\n\n",
            f"### Instruction:\n{user_content}\n\n### Response:\n",
        ]
        if response:
            pieces.append(response)
        return "".join(pieces)
class SftPrompt(ABC):
    """Abstract base for supervised-fine-tuning prompt builders.

    Subclasses implement the three ``prompt_*`` hooks; ``prompt`` combines
    their results for one data record.
    """

    @classmethod
    @abstractmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Build the user-content part of the prompt from ``data``."""

    @classmethod
    @abstractmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Optional[str]:
        """Build the system-content part of the prompt, or ``None``."""

    @classmethod
    @abstractmethod
    def prompt_target(cls, sql: str) -> str:
        """Format ``sql`` as the training target."""

    @classmethod
    def prompt(
        cls, data: Dict[str, Any]
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """Build every prompt component for one record.

        Args:
            data: The record passed to the ``prompt_*`` hooks. When it has a
                ``"spark_sql"`` key, that value is formatted as the target.

        Returns:
            ``(user_content, system_content, target)``. ``target`` is
            ``None`` when ``data`` has no ``"spark_sql"`` key, which is why
            the last tuple element is annotated as optional.
        """
        user_content = cls.prompt_user_content(data)
        system_content = cls.prompt_system_content(data)
        # Key presence (not truthiness) decides whether a target is built.
        target = cls.prompt_target(data["spark_sql"]) if "spark_sql" in data else None
        return user_content, system_content, target
class SQLGeneratePrompt(SftPrompt):
    """Build and parse prompts for the Spark SQL generation task.

    The ``compose_*`` / ``prompt_*`` methods render bracket-delimited
    sections (task instruction, table schemas, optional hint, optional
    related queries, format instruction, user query); the ``extract_*``
    methods recover those sections from a rendered prompt, and ``split_res``
    unwraps a fenced model response.
    """

    @classmethod
    def split_res(
        cls, res: Union[str, List[str], None]
    ) -> Union[str, List[Union[str, None]], None]:
        """Unwrap the SQL from a fenced model response.

        Args:
            res: One response string, a list of responses, or ``None``.

        Returns:
            For a string: the stripped text between a leading ```` ```sql ````
            and a trailing ```` ``` ````, or ``None`` when the response is not
            wrapped that way. For a list: a new list with each element
            processed the same way (the input list is left untouched).
            ``None`` is passed through.
        """
        if res is None:
            return None
        if isinstance(res, list):
            # Build a fresh list instead of overwriting the caller's elements.
            return [cls.split_res(item) for item in res]
        text = res.strip()
        if not (text.startswith("```sql") and text.endswith("```")):
            return None
        # Drop the 6-char opening fence and the 3-char closing fence.
        return text[6:-3].strip()

    @classmethod
    def compose_extra_task_desc(
        cls,
        hint: Union[str, None],
        related_question_sqls: Union[List[Dict[str, str]], None],
    ) -> str:
        """Return the instruction sentence announcing the optional sections.

        Args:
            hint: Generation hint; ``None``, empty, or whitespace-only counts
                as absent so the sentence matches what
                ``compose_hint_content`` renders.
            related_question_sqls: Related question/SQL pairs; ``None`` or an
                empty list counts as absent.

        Returns:
            A sentence ending in a newline that mentions the hint and/or the
            related queries, or ``""`` when neither is present.
        """
        has_hint = bool(hint and hint.strip())
        has_related = bool(related_question_sqls)
        if has_hint and has_related:
            return "I provide generation hint and other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        if has_hint:
            return (
                "I provide generation hint, you can refer to it to help you generate.\n"
            )
        if has_related:
            return "I provide other user queries related to the user query with the corresponding Spark SQL queries, you can refer to them to help you generate.\n"
        return ""

    @classmethod
    def compose_hint_content(cls, hint: Union[str, None]) -> str:
        """Render the generation-hint section.

        Args:
            hint: Hint text, or ``None``.

        Returns:
            The stripped hint wrapped in ``[BEGIN/END OF GENERATION HINT]``
            markers followed by a blank line, or ``""`` when the hint is
            ``None``, empty, or whitespace-only.
        """
        if hint is None:
            return ""
        # Strip before the emptiness test so a blank hint yields no section.
        hint = hint.strip()
        if not hint:
            return ""
        return (
            "[BEGIN OF GENERATION HINT]\n" f"{hint}\n" "[END OF GENERATION HINT]\n" "\n"
        )

    @classmethod
    def compose_related_question_sqls_content(
        cls, related_question_sqls: Union[List[Dict[str, str]], None]
    ) -> str:
        """Render the related-queries section.

        Args:
            related_question_sqls: Dicts with ``"question"`` and
                ``"spark_sql"`` keys, or ``None``.

        Returns:
            Numbered entries wrapped in ``[BEGIN/END OF RELATED QUERIES]``
            markers followed by a blank line, or ``""`` when there are no
            entries. Each question is stripped and flattened to one line;
            each SQL is stripped and fenced.
        """
        if not related_question_sqls:
            return ""
        entries = []
        for number, question_sql in enumerate(related_question_sqls, start=1):
            question = question_sql["question"].strip().replace("\n", " ")
            sql = question_sql["spark_sql"].strip()
            entries.append(
                f"# Related Query {number}\n"
                "## User Query\n"
                f"`{question}`\n"
                "## Spark SQL Query\n"
                "```sql\n"
                f"{sql}\n"
                "```"
            )
        # Entries are separated by one blank line; the last one is followed
        # directly by the end marker.
        body = "\n\n".join(entries)
        return f"[BEGIN OF RELATED QUERIES]\n{body}\n[END OF RELATED QUERIES]\n\n"

    @classmethod
    def _extract_between(
        cls, text: str, begin_marker: str, end_marker: str
    ) -> Optional[str]:
        """Return the stripped text between two markers, or ``None``.

        ``None`` is returned when the begin marker is missing or no end
        marker follows it, so callers never slice with a ``-1`` index.
        """
        begin_idx = text.find(begin_marker)
        if begin_idx == -1:
            return None
        start_idx = begin_idx + len(begin_marker)
        end_idx = text.find(end_marker, start_idx)
        if end_idx == -1:
            return None
        return text[start_idx:end_idx].strip()

    @classmethod
    def extract_table_schema(cls, user_content: str) -> str:
        """Return the table-schema section of a rendered prompt.

        Returns ``""`` when the section markers are not both present.
        """
        schema = cls._extract_between(
            user_content, "[BEGIN OF TABLE SCHEMAS]", "[END OF TABLE SCHEMAS]"
        )
        return "" if schema is None else schema

    @classmethod
    def extract_user_query(cls, user_content: str) -> str:
        """Return the user query of a rendered prompt.

        Returns ``""`` when the section markers are not both present.
        """
        query = cls._extract_between(
            user_content, "[BEGIN OF QUERY]\nUser Query: ", "[END OF QUERY]"
        )
        return "" if query is None else query

    @classmethod
    def extract_hint(cls, user_content: str) -> Union[str, None]:
        """Return the generation hint of a rendered prompt.

        Returns ``None`` when the prompt has no complete hint section.
        """
        return cls._extract_between(
            user_content, "[BEGIN OF GENERATION HINT]", "[END OF GENERATION HINT]"
        )

    @classmethod
    def extract_related_question_sqls(
        cls, user_content: str
    ) -> Union[List[Dict[str, str]], None]:
        """Parse the related-queries section of a rendered prompt.

        Args:
            user_content: Prompt text as produced by ``prompt_user_content``.

        Returns:
            A list of ``{"question": ..., "spark_sql": ...}`` dicts, or
            ``None`` when the prompt has no complete related-queries section.
            The SQL is returned stripped, i.e. exactly as
            ``compose_related_question_sqls_content`` rendered it. Entries
            whose question or SQL fence cannot be located are skipped.
        """
        section = cls._extract_between(
            user_content, "[BEGIN OF RELATED QUERIES]", "[END OF RELATED QUERIES]"
        )
        if section is None:
            return None
        sql_fence = "```sql\n"
        res = []
        # Split on the entry header rather than on every blank line: a SQL
        # statement may itself contain blank lines (or be empty).
        for entry in section.split("\n\n# Related Query "):
            question_start = entry.find("`") + 1
            question_end = entry.find("`\n##", question_start)
            fence_start = entry.find(sql_fence, question_end)
            if question_start == 0 or question_end == -1 or fence_start == -1:
                continue
            sql = entry[fence_start + len(sql_fence):].rstrip()
            if sql.endswith("```"):
                sql = sql[:-3]
            res.append(
                {
                    "question": entry[question_start:question_end],
                    "spark_sql": sql.strip(),
                }
            )
        return res

    @classmethod
    def prompt_user_content(cls, data: Dict[str, Any]) -> str:
        """Render the complete user prompt for one record.

        Args:
            data: Record with required ``"question"`` and ``"schema"``
                strings and optional ``"hint"`` and
                ``"related_question_sqls"`` entries (missing keys are
                treated like ``None``).

        Returns:
            The task prompt wrapped by ``AlpacaTemplate.template``.
        """
        question, schema = data["question"].strip(), data["schema"].strip()
        # Both extras are optional; the compose helpers already accept None.
        hint = data.get("hint")
        related_question_sqls = data.get("related_question_sqls")
        extra_task_desc = cls.compose_extra_task_desc(hint, related_question_sqls)
        hint_content = cls.compose_hint_content(hint)
        related_question_sqls_content = cls.compose_related_question_sqls_content(
            related_question_sqls
        )
        user_content = (
            "[BEGIN OF TASK INSTRUCTION]\n"
            "You are an expert in composing Spark SQL queries. You are given a user query and a set of table schemas.\n"
            "Based on the user query, you need to generate one Spark SQL query to achieve the purpose.\n"
            f"{extra_task_desc}"
            "[END OF TASK INSTRUCTION]\n"
            "\n"
            "[BEGIN OF TABLE SCHEMAS]\n"
            f"{schema}\n"
            "[END OF TABLE SCHEMAS]\n"
            "\n"
            f"{hint_content}"
            f"{related_question_sqls_content}"
            "[BEGIN OF FORMAT INSTRUCTION]\n"
            "The output MUST strictly adhere to the following format, and NO other text MUST be included.\n"
            "```sql\n"
            "your output Spark SQL query\n"
            "```\n"
            "[END OF FORMAT INSTRUCTION]\n"
            "\n"
            "[BEGIN OF QUERY]\n"
            f"User Query: {question}\n"
            "[END OF QUERY]\n"
        )
        return AlpacaTemplate.template(user_content)

    @classmethod
    def prompt_system_content(cls, data: Dict[str, Any]) -> Union[str, None]:
        """Return ``None``: this prompt defines no custom system content."""
        return None

    @classmethod
    def prompt_target(cls, sql: Union[str, List[str]]) -> Union[str, List[str]]:
        """Format SQL as the fenced training target.

        Args:
            sql: One SQL string or a list of SQL strings.

        Returns:
            Each SQL with surrounding whitespace and semicolons removed,
            terminated by exactly one ``;`` and wrapped in a ```` ```sql ````
            fence; a list input yields a list of such strings.
        """

        def fence(raw_sql: str) -> str:
            statement = raw_sql.strip()
            # Repeat until clean: one strip pass leaves "SELECT 1 ;" behind
            # for input such as "SELECT 1 ; ;", which would end in ";;".
            while statement.startswith(";") or statement.endswith(";"):
                statement = statement.strip(";").strip()
            return "```sql\n" f"{statement};\n" "```"

        if isinstance(sql, str):
            return fence(sql)
        return [fence(item) for item in sql]