AI-powered-SQL / src /pipeline /gen_query.py
github-actions
Auto deploy from GitHub Actions
1914b78
# src/pipeline/gen_query.py
import logging
from src.template.prompt import generate_message_template, generate_refine_template
from src.utils.config import get_model
import torch
# Module-level logger named after this module; generate_query logs each
# stage's SQL through it.
logger = logging.getLogger(__name__)
def _call_llm(llm, messages: list, max_tokens: int = 128) -> str:
response = llm.create_chat_completion(
messages=messages,
max_tokens=max_tokens,
stop=["</s>", "\n\n"],
temperature=0.1,
)
return response["choices"][0]["message"]["content"].strip()
def generate_query(user_query: str, model_schema) -> dict:
    """Generate SQL for *user_query* through a three-stage draft/refine chain.

    Stage 1 drafts SQL from ``generate_message_template``. Stages 2 and 3
    each pass the previous stage's SQL to ``generate_refine_template``
    (with ``stage=2`` / ``stage=3``) and ask the model to refine it. All
    three calls use the first model object returned by ``get_model()``.
    Each stage's SQL is logged at INFO level and also printed to stdout.

    Args:
        user_query: Natural-language question to translate into SQL.
        model_schema: Schema description forwarded unchanged to the prompt
            templates.

    Returns:
        dict with keys ``"final"`` (identical to ``"stage_3"``), ``"stage_1"``,
        ``"stage_2"`` and ``"stage_3"``, each holding the stripped model reply
        for that stage.
    """
    llm, _ = get_model()

    # Console prefix per stage. Written as \N{...} escapes so the emoji cannot
    # be corrupted into mojibake by a source-file encoding round-trip.
    stage_labels = {
        1: "\N{LARGE BLUE CIRCLE} Stage 1",
        2: "\N{LARGE YELLOW CIRCLE} Stage 2",
        3: "\N{LARGE GREEN CIRCLE} Stage 3 (final)",
    }

    stage_sql = {}
    sql = ""
    for stage, label in stage_labels.items():
        if stage == 1:
            # Draft directly from the question and schema.
            messages = generate_message_template(user_query, model_schema)
        else:
            # Refine the SQL produced by the previous stage.
            messages = generate_refine_template(
                user_query, model_schema, sql, stage=stage
            )
        sql = _call_llm(llm, messages)
        logger.info("Stage %d SQL: %s", stage, sql)
        print(f"{label}: {sql}")
        stage_sql[f"stage_{stage}"] = sql

    # The last stage's output is the final answer; key order matches the
    # historical {"final", "stage_1", "stage_2", "stage_3"} layout.
    return {"final": sql, **stage_sql}
def generate_query_trans(
    user_query,
    model_schema,
    *,
    max_new_tokens: int = 64,
    num_threads: int = 4,
) -> str:
    """Translate *user_query* into SQL with a single greedy generation pass.

    Builds the prompt with ``generate_message_template`` and renders it
    through the tokenizer's chat template. It then runs ``model.generate``
    with sampling disabled and decodes only the newly generated tokens.
    Unlike ``generate_query``, this path has no refinement stages.

    Args:
        user_query: Natural-language question to translate into SQL.
        model_schema: Schema description forwarded to the prompt template.
        max_new_tokens: Upper bound on generated tokens (default 64).
        num_threads: Passed to ``torch.set_num_threads``. This is a
            process-wide setting and persists after the call.

    Returns:
        The decoded completion, excluding prompt tokens and special tokens,
        with surrounding whitespace stripped.
    """
    model, tokenizer = get_model()
    # Put inputs on the same device as the model instead of assuming CPU, so
    # a GPU-resident model also works. Fall back to CPU when the model object
    # does not expose a ``device`` attribute.
    device = getattr(model, "device", "cpu")
    torch.set_num_threads(num_threads)  # CPU intra-op threads; tune 2-8

    messages = generate_message_template(user_query, model_schema)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=False)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    prompt_length = input_ids.shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            use_cache=True,
            # A single prompt is tokenized without padding, so EOS is simply
            # reused as the pad id.
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output row contains prompt + completion; keep only the new tokens.
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()