Spaces:
Runtime error
Runtime error
File size: 2,428 Bytes
# src/pipeline/gen_query.py
import logging
from src.template.prompt import generate_message_template, generate_refine_template
from src.utils.config import get_model
import torch
logger = logging.getLogger(__name__)
def _call_llm(llm, messages: list, max_tokens: int = 128) -> str:
    """Run one chat completion on ``llm`` and return the reply text, stripped.

    Args:
        llm: Object exposing ``create_chat_completion(messages=..., max_tokens=...,
            stop=..., temperature=...)`` that returns a mapping shaped like
            ``{"choices": [{"message": {"content": ...}}]}``.
        messages: Chat messages, forwarded unchanged to the completion call.
        max_tokens: Token limit forwarded to the completion call.

    Returns:
        The first choice's message content with leading/trailing whitespace
        removed, or an empty string when that content is ``None``.
    """
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        # Generation is cut at the "</s>" marker or at the first blank line.
        stop=["</s>", "\n\n"],
        temperature=0.1,
    )
    content = response["choices"][0]["message"]["content"]
    # A None content field would make .strip() raise AttributeError; normalise
    # it so callers always receive a str, as the signature promises.
    if content is None:
        return ""
    return content.strip()
def generate_query(user_query: str, model_schema) -> dict:
    """Generate SQL for ``user_query`` with one draft pass and two refinements.

    Stage 1 prompts the model with ``generate_message_template``. Stages 2 and 3
    pass the previous stage's output back through ``generate_refine_template``.
    Every stage's output is written to the module logger and printed to stdout.

    Args:
        user_query: The user's request text, forwarded to the prompt templates.
        model_schema: Schema object, forwarded unchanged to the prompt templates.

    Returns:
        A dict with keys ``"stage_1"``, ``"stage_2"`` and ``"stage_3"`` holding
        each stage's output, plus ``"final"``, which equals the stage 3 output.
    """
    # Only the first element returned by get_model() is used on this path.
    llm, _ = get_model()

    # Stage 1: initial draft from the base prompt.
    messages_1 = generate_message_template(user_query, model_schema)
    sql_1 = _call_llm(llm, messages_1)
    logger.info("Stage 1 SQL: %s", sql_1)
    print(f"🔵 Stage 1: {sql_1}")

    # Stage 2: refine the stage 1 draft.
    messages_2 = generate_refine_template(user_query, model_schema, sql_1, stage=2)
    sql_2 = _call_llm(llm, messages_2)
    logger.info("Stage 2 SQL: %s", sql_2)
    print(f"🟡 Stage 2: {sql_2}")

    # Stage 3: refine the stage 2 result; this is what callers get as "final".
    messages_3 = generate_refine_template(user_query, model_schema, sql_2, stage=3)
    sql_3 = _call_llm(llm, messages_3)
    logger.info("Stage 3 SQL: %s", sql_3)
    print(f"🟢 Stage 3 (final): {sql_3}")

    return {
        "final": sql_3,
        "stage_1": sql_1,
        "stage_2": sql_2,
        "stage_3": sql_3,
    }
def generate_query_trans(user_query, model_schema):
    """Generate a reply for ``user_query`` with a single ``generate`` call.

    Builds the prompt from ``generate_message_template``, renders it through
    the tokenizer's chat template, runs generation with sampling disabled and
    at most 64 new tokens, and decodes only the tokens produced after the
    prompt.

    NOTE(review): assumes ``get_model()`` returns a ``(model, tokenizer)`` pair
    offering ``generate`` / ``apply_chat_template`` / ``decode`` — confirm
    against ``src.utils.config``.

    Args:
        user_query: The user's request, forwarded to the prompt template.
        model_schema: Schema object, forwarded to the prompt template.

    Returns:
        The decoded continuation with special tokens skipped and surrounding
        whitespace stripped.
    """
    model, tokenizer = get_model()

    torch.set_num_threads(4)  # tune: try 2–8
    target_device = "cpu"  # keep cpu unless GPU available

    # Render the chat messages to plain prompt text (no tokenization yet).
    prompt_text = tokenizer.apply_chat_template(
        generate_message_template(user_query, model_schema),
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer(prompt_text, return_tensors="pt", padding=False)
    prompt_ids = encoded["input_ids"].to(target_device)
    prompt_mask = encoded["attention_mask"].to(target_device)

    # Remember the prompt length so it can be sliced off the output below.
    prompt_len = prompt_ids.shape[1]

    with torch.inference_mode():
        generated = model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            max_new_tokens=64,
            do_sample=False,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the tokens that follow the prompt in the first sequence.
    new_tokens = generated[0][prompt_len:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return decoded.strip()
|