# src/pipeline/gen_query.py
"""Turn a natural-language question plus a schema into SQL with a local LLM.

Two entry points:

* ``generate_query`` drives a chat-completion backend through three passes
  (an initial draft, then two refinement passes) and returns every stage.
* ``generate_query_trans`` runs a single greedy generation with a
  transformers-style model/tokenizer pair (``apply_chat_template`` /
  ``generate`` / ``decode``).
"""
import logging

import torch

from src.template.prompt import (
    generate_message_template,
    generate_refine_template,
)
from src.utils.config import get_model

logger = logging.getLogger(__name__)

# Default stop sequences for the chat-completion backend. Never include an
# empty string here: "" is a substring of every output, so a substring-based
# stop check (e.g. llama-cpp-python's) would end generation before any text
# is returned.
# TODO(review): if an explicit end-of-turn marker was intended, add it here.
_DEFAULT_STOP = ("\n\n",)

# Console badges printed for each stage of the draft/refine pipeline.
_STAGE_BADGES = {
    1: "🔵 Stage 1",
    2: "🟡 Stage 2",
    3: "🟢 Stage 3 (final)",
}


def _call_llm(llm, messages: list, max_tokens: int = 128, stop=None) -> str:
    """Run one chat completion and return the stripped assistant text.

    Args:
        llm: Object exposing ``create_chat_completion`` and returning an
            OpenAI-style response dict (presumably ``llama_cpp.Llama`` --
            confirm against ``get_model``).
        messages: Chat messages forwarded unchanged to the backend.
        max_tokens: Upper bound on the number of generated tokens.
        stop: Optional iterable of stop sequences; defaults to
            ``_DEFAULT_STOP``. Empty strings are dropped because they would
            match immediately and truncate the output to nothing.

    Returns:
        The assistant message content with surrounding whitespace removed,
        or ``""`` when the backend returned no content.
    """
    stop_sequences = [s for s in (_DEFAULT_STOP if stop is None else stop) if s]
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        stop=stop_sequences,
        temperature=0.1,  # low temperature keeps SQL output near-deterministic
    )
    content = response["choices"][0]["message"]["content"]
    # Guard against a None content field instead of crashing on .strip().
    return (content or "").strip()


def _report_stage(stage: int, sql: str) -> None:
    """Log the SQL produced by *stage* and echo it to stdout with a badge."""
    logger.info("Stage %d SQL: %s", stage, sql)
    print(f"{_STAGE_BADGES[stage]}: {sql}")


def generate_query(user_query: str, model_schema) -> dict:
    """Generate SQL for *user_query* via a draft pass and two refinement passes.

    Stage 1 builds a prompt with ``generate_message_template``. Stages 2 and
    3 each feed the previous stage's SQL back through
    ``generate_refine_template`` (with ``stage=2`` / ``stage=3``).

    Args:
        user_query: The natural-language question.
        model_schema: Schema description passed through to the prompt
            templates (its exact type is defined by ``src.template.prompt``).

    Returns:
        ``{"final": ..., "stage_1": ..., "stage_2": ..., "stage_3": ...}``
        where ``final`` equals ``stage_3``.
    """
    # get_model() returns a pair; only the first element (the LLM) is used here.
    llm, _ = get_model()

    results = {}
    sql = _call_llm(llm, generate_message_template(user_query, model_schema))
    _report_stage(1, sql)
    results["stage_1"] = sql

    # Each refinement pass sees the SQL produced by the pass before it.
    for stage in (2, 3):
        messages = generate_refine_template(
            user_query, model_schema, sql, stage=stage
        )
        sql = _call_llm(llm, messages)
        _report_stage(stage, sql)
        results[f"stage_{stage}"] = sql

    return {"final": sql, **results}


def generate_query_trans(
    user_query, model_schema, *, max_new_tokens: int = 64, num_threads: int = 4
):
    """Generate SQL in a single greedy pass with a transformers-style model.

    Args:
        user_query: The natural-language question.
        model_schema: Schema description passed to
            ``generate_message_template``.
        max_new_tokens: Maximum number of tokens to generate.
        num_threads: Value passed to ``torch.set_num_threads``. This is a
            process-wide setting (tune: try 2-8).

    Returns:
        The decoded continuation (prompt tokens excluded, special tokens
        skipped) with surrounding whitespace removed.
    """
    model, tokenizer = get_model()
    # Put inputs on the same device as the model so a GPU-loaded model does
    # not hit a device mismatch; fall back to CPU if no .device is exposed.
    device = getattr(model, "device", "cpu")
    torch.set_num_threads(num_threads)

    messages = generate_message_template(user_query, model_schema)
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(text, return_tensors="pt", padding=False)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    input_length = input_ids.shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    generated_tokens = outputs[0][input_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response.strip()