File size: 2,428 Bytes
1914b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# src/pipeline/gen_query.py

import logging
from src.template.prompt import generate_message_template, generate_refine_template
from src.utils.config import get_model
import torch 

logger = logging.getLogger(__name__)

def _call_llm(
    llm,
    messages: list,
    max_tokens: int = 128,
    *,
    temperature: float = 0.1,
) -> str:
    """Run one chat completion and return the assistant's text, stripped.

    Args:
        llm: Object exposing ``create_chat_completion(...)`` that returns a
            ``{"choices": [{"message": {"content": ...}}]}`` mapping
            (presumably a llama-cpp ``Llama`` instance — confirm in get_model).
        messages: Chat messages forwarded verbatim to the model.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; the low default keeps SQL output
            close to deterministic.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the model returned no content.
    """
    response = llm.create_chat_completion(
        messages=messages,
        max_tokens=max_tokens,
        # Stop at an end-of-sequence marker or the first blank line, so the
        # model does not ramble past a single query.
        stop=["</s>", "\n\n"],
        temperature=temperature,
    )
    content = response["choices"][0]["message"]["content"]
    # Chat responses may carry ``content: None``; treat that as empty text
    # instead of crashing on ``None.strip()``.
    if content is None:
        return ""
    return content.strip()


def generate_query(user_query: str, model_schema) -> dict:
    """Generate SQL for ``user_query`` through a three-stage draft/refine loop.

    Stage 1 drafts SQL from the base prompt. Stages 2 and 3 each send the
    previous stage's SQL back through the refine template. Every stage's
    output is logged at INFO and echoed to stdout.

    Args:
        user_query: Natural-language question to translate into SQL.
        model_schema: Schema description forwarded to the prompt templates
            imported from ``src.template.prompt``.

    Returns:
        dict with ``"final"`` (the stage-3 SQL) plus ``"stage_1"``,
        ``"stage_2"`` and ``"stage_3"`` holding each stage's output.
    """
    # get_model() returns a pair; only the first element is used on this path.
    llm, _ = get_model()

    # Stage 1: initial draft.
    sql_1 = _call_llm(llm, generate_message_template(user_query, model_schema))
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Stage 1 SQL: %s", sql_1)
    print(f"🔵 Stage 1: {sql_1}")

    # Stage 2: refine the stage-1 draft.
    sql_2 = _call_llm(
        llm, generate_refine_template(user_query, model_schema, sql_1, stage=2)
    )
    logger.info("Stage 2 SQL: %s", sql_2)
    print(f"🟡 Stage 2: {sql_2}")

    # Stage 3: refine again; this output is returned as the final answer.
    sql_3 = _call_llm(
        llm, generate_refine_template(user_query, model_schema, sql_2, stage=3)
    )
    logger.info("Stage 3 SQL: %s", sql_3)
    print(f"🟢 Stage 3 (final): {sql_3}")

    return {
        "final": sql_3,
        "stage_1": sql_1,
        "stage_2": sql_2,
        "stage_3": sql_3,
    }





def generate_query_trans(
    user_query,
    model_schema,
    *,
    max_new_tokens: int = 64,
    num_threads: int = 4,
):
    """Generate SQL for ``user_query`` in a single greedy pass with a transformers model.

    Unlike ``generate_query``, there are no refinement stages. The stage-1
    prompt is rendered with the tokenizer's chat template, decoded greedily,
    and only the newly generated text is returned.

    Args:
        user_query: Natural-language question to translate into SQL.
        model_schema: Schema description forwarded to the prompt template.
        max_new_tokens: Upper bound on generated tokens (default 64, as
            before). Output longer than this is truncated.
        num_threads: Passed to ``torch.set_num_threads``. This is a
            process-wide setting; values around 2-8 are worth trying for CPU
            inference.

    Returns:
        The decoded completion with special tokens removed and surrounding
        whitespace stripped.
    """
    model, tokenizer = get_model()

    # Send inputs to wherever the model actually lives. The previous
    # hard-coded "cpu" caused a device mismatch when get_model() returned a
    # GPU-resident model. Fall back to CPU for objects without ``.device``.
    device = getattr(model, "device", "cpu")

    torch.set_num_threads(num_threads)

    messages = generate_message_template(user_query, model_schema)

    # Render the chat messages to a single prompt string. Tokenization
    # happens below.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
    )

    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Number of prompt tokens, used to slice the prompt off the output below.
    prompt_length = input_ids.shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            use_cache=True,
            # presumably set so generate() has a pad id for tokenizers that
            # define none — confirm against the tokenizer in use
            pad_token_id=tokenizer.eos_token_id,
        )

    # Assumes a decoder-only model, whose generate() output starts with the
    # prompt tokens; keep only the continuation.
    completion_ids = outputs[0][prompt_length:]

    completion = tokenizer.decode(
        completion_ids,
        skip_special_tokens=True,
    )

    return completion.strip()