Spaces:
Runtime error
Runtime error
| # src/pipeline/gen_query.py | |
| import logging | |
| from src.template.prompt import generate_message_template, generate_refine_template | |
| from src.utils.config import get_model | |
| import torch | |
| logger = logging.getLogger(__name__) | |
| def _call_llm(llm, messages: list, max_tokens: int = 128) -> str: | |
| response = llm.create_chat_completion( | |
| messages=messages, | |
| max_tokens=max_tokens, | |
| stop=["</s>", "\n\n"], | |
| temperature=0.1, | |
| ) | |
| return response["choices"][0]["message"]["content"].strip() | |
def generate_query(user_query: str, model_schema) -> dict:
    """Generate SQL for ``user_query`` in three LLM passes.

    Stage 1 drafts a query from ``generate_message_template``. Stages 2 and 3
    each ask the model to refine the previous SQL via
    ``generate_refine_template(..., stage=N)``. Each stage's raw output is
    logged at INFO level and echoed to stdout.

    If a stage returns empty text (possible because ``_call_llm`` stops on a
    blank line), the next refinement is given the most recent non-empty SQL
    instead of an empty string.

    Args:
        user_query: The natural-language question.
        model_schema: Schema description forwarded unchanged to the prompt
            templates (its type is defined by ``src.template.prompt``).

    Returns:
        ``{"final": ..., "stage_1": ..., "stage_2": ..., "stage_3": ...}``.
        The ``stage_N`` entries hold each pass's raw output. ``final`` is the
        last non-empty stage output (normally ``stage_3``), or ``""`` if every
        stage came back empty.
    """
    llm, _ = get_model()

    # Console prefixes kept verbatim from the original prints (the leading
    # glyphs look like mis-encoded emoji in this copy of the file).
    console_prefix = {
        1: "π΅ Stage 1",
        2: "π‘ Stage 2",
        3: "π’ Stage 3 (final)",
    }

    stage_outputs = {}
    best_sql = ""  # most recent non-empty stage output
    for stage in (1, 2, 3):
        if stage == 1:
            messages = generate_message_template(user_query, model_schema)
        else:
            messages = generate_refine_template(
                user_query, model_schema, best_sql, stage=stage
            )
        sql = _call_llm(llm, messages)
        logger.info("Stage %d SQL: %s", stage, sql)
        print(f"{console_prefix[stage]}: {sql}")
        stage_outputs[f"stage_{stage}"] = sql
        if sql:
            best_sql = sql

    return {"final": best_sql, **stage_outputs}
def generate_query_trans(
    user_query,
    model_schema,
    *,
    max_new_tokens: int = 64,
    num_threads: int = 4,
):
    """Generate SQL for ``user_query`` in a single greedy pass (transformers path).

    Builds the prompt with ``generate_message_template`` and renders it through
    the tokenizer's chat template. It then runs ``model.generate`` without
    sampling and decodes only the newly generated tokens.

    NOTE(review): ``generate_query`` treats ``get_model()[0]`` as a
    llama-cpp-style object (``create_chat_completion``), while this function
    treats ``get_model()`` as a transformers model plus tokenizer. Confirm
    which pair ``get_model`` actually returns; only one path can match it.

    Args:
        user_query: The natural-language question.
        model_schema: Schema description forwarded unchanged to the prompt
            template.
        max_new_tokens: Cap on generated tokens (default 64, as before). Long
            SQL statements may be cut off at this limit.
        num_threads: Intra-op CPU thread count for ``torch.set_num_threads``,
            which is a process-wide setting. Falsy values leave the current
            setting untouched.

    Returns:
        The decoded completion with special tokens removed and surrounding
        whitespace stripped.
    """
    model, tokenizer = get_model()

    # Put inputs on the model's own device so a GPU-resident model does not hit
    # a device mismatch; objects without a ``device`` attribute fall back to CPU.
    device = getattr(model, "device", "cpu")

    # Global torch setting: only change it when it differs (worth tuning, 2-8).
    if num_threads and torch.get_num_threads() != num_threads:
        torch.set_num_threads(num_threads)

    messages = generate_message_template(user_query, model_schema)
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(prompt, return_tensors="pt", padding=False)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    prompt_length = input_ids.shape[1]

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            use_cache=True,
            # presumably set because the tokenizer may define no pad token — confirm
            pad_token_id=tokenizer.eos_token_id,
        )

    # The output row starts with the prompt tokens; keep only the continuation.
    completion_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()