pp542-0965 committed on
Commit ·
c52a50c
1
Parent(s): d3cf308
Add gradio app
Browse files
app.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
from peft import PeftModel
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_model_tokenizer():
    """Load the base Qwen2.5-3B-Instruct model with its GRPO-trained SQL LoRA adapter.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference. The adapter is
        loaded frozen (``is_trainable=False``) and the model is switched to
        eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
    model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)
    # Inference only: disable dropout/training-mode layers so generation is
    # not perturbed by train-mode stochasticity.
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)

    return model, tokenizer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Load once at import time so every Gradio request reuses the same model/tokenizer.
model, tokenizer = load_model_tokenizer()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_prompt(schemas, question):
    """Build the two-message chat prompt for SQL generation.

    Args:
        schemas: CREATE TABLE statements (newline separated) to ground the query.
        question: Natural-language question to be answered with SQL.

    Returns:
        list[dict]: system + user messages in chat-template format.
    """
    system_message = """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.

Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.

An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
    COLUMN
FROM TABLE_NAME
WHERE
    CONDITION
</answer>"""

    user_message = f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""

    return [
        {'role': 'system', 'content': system_message},
        {'role': 'user', 'content': user_message},
    ]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def extract_answer(gen_output):
    """Extract the text between the first <answer>...</answer> pair.

    Matching is case-insensitive and spans newlines (DOTALL).

    Args:
        gen_output: Full decoded model generation.

    Returns:
        str | None: The captured answer body, or None when no tag pair exists.
    """
    pattern = re.compile(
        r"<answer>(.+?)</answer>",
        flags=re.MULTILINE | re.DOTALL | re.IGNORECASE,
    )

    found = pattern.search(gen_output)
    if found is None:
        return None
    return found.group(1)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def response(user_schemas, user_question):
    """Generate a reasoned SQL answer for the given schemas and question.

    Args:
        user_schemas: CREATE TABLE statements typed into the UI.
        user_question: Natural-language question typed into the UI.

    Returns:
        str: The raw assistant generation followed by the extracted final answer.
    """
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt")

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    # Keep only the assistant turn of the decoded ChatML transcript.
    output = outputs[0].split("<|im_start|>assistant")[-1]

    final_answer = extract_answer(output)
    # BUG FIX: extract_answer returns None when the model omits <answer> tags;
    # concatenating None onto str raised TypeError and crashed the request.
    if final_answer is None:
        final_answer = "(no <answer>...</answer> block found in the generation)"

    return output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Help text shown under the Gradio app title.
# FIX: corrected the doubled word "to to" in the first sentence.
desc = """
Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.
Eg. CREATE TABLE demographic (
        subject_id text,
        admission_type text,
        hadm_id text)

    CREATE TABLE diagnoses (
        subject_id text,
        hadm_id text)

Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.
Eg. How many patients whose admission type is emergency.
"""
|
| 114 |
+
|
| 115 |
+
# Wire the generation function into a simple two-input Gradio form:
# schemas + question in, annotated generation out.
demo = gr.Interface(
    fn=response,
    inputs=[gr.Textbox(label="Table Schemas",
                       placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines"),
            gr.Textbox(label="Question",
                       placeholder="Eg. How many patients whose admission type is emergency")
            ],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc
)

# Start the app (blocks; serves on the default local address unless configured).
demo.launch()
|