dvwn committed on
Commit
e06da36
·
1 Parent(s): a4607e0

hf_engine.py & sql_agent.py version 1.1.0

Browse files
backend/src/nl2sql/hf_engine.py CHANGED
@@ -6,10 +6,6 @@ from langchain_huggingface import HuggingFaceEndpoint
6
  from langchain_core.language_models.llms import LLM
7
  from typing import Any, List, Optional
8
 
9
- # Default Model
10
- # DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
11
- # DEFAULT_MODEL_ID = "defog/sqlcoder-7b-2"
12
- # DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct:featherless-ai"
13
  # Model Registry: Add several model to be tested
14
  MODEL_REGISTRY = {
15
  "defog/sqlcoder-7b-2": "text",
@@ -19,7 +15,7 @@ MODEL_REGISTRY = {
19
  #"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat"
20
  }
21
 
22
- ACTIVE_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai"
23
 
24
  # Custom LangChain wrapper for HuggingFace Inference API
25
  class HFChatWrapper(LLM):
@@ -43,9 +39,13 @@ class HFChatWrapper(LLM):
43
  @property
44
  def _llm_type(self) -> str:
45
  return "huggingface_inference_client"
46
-
 
 
 
 
47
  # Initialize the HuggingFace endpoint using the InferenceClient
48
- def get_llm(model_id: str = ACTIVE_MODEL_ID):
49
  """
50
  Automatically detects the model type and returns the correct LangChain interface.
51
  Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
@@ -55,16 +55,22 @@ def get_llm(model_id: str = ACTIVE_MODEL_ID):
55
  if not hf_token:
56
  raise ValueError("HuggingFace API token not found!")
57
 
58
- model_type = MODEL_REGISTRY.get(model_id, "chat")
59
- print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
 
 
 
 
 
 
60
 
61
  if model_type == "chat":
62
  client = InferenceClient(api_key=hf_token)
63
- return HFChatWrapper(client=client, model_id=model_id)
64
  elif model_type == "text":
65
  # Route to standard Text Generation API
66
  return HuggingFaceEndpoint(
67
- repo_id=model_id,
68
  task="text-generation",
69
  max_new_tokens=512,
70
  temperature=0.0,
@@ -77,7 +83,7 @@ def get_llm(model_id: str = ACTIVE_MODEL_ID):
77
 
78
  # Initialize the HuggingFace InferenceClient
79
  #client = InferenceClient(api_key=hf_token)
80
- #llm = HFChatWrapper(client=client, model_id=model_id)
81
 
82
  #return llm
83
 
 
6
  from langchain_core.language_models.llms import LLM
7
  from typing import Any, List, Optional
8
 
 
 
 
 
9
  # Model Registry: Add several model to be tested
10
  MODEL_REGISTRY = {
11
  "defog/sqlcoder-7b-2": "text",
 
15
  #"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B:featherless-ai": "chat"
16
  }
17
 
18
+ DEFAULT_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct:featherless-ai"
19
 
20
  # Custom LangChain wrapper for HuggingFace Inference API
21
  class HFChatWrapper(LLM):
 
39
  @property
40
  def _llm_type(self) -> str:
41
  return "huggingface_inference_client"
42
+
43
+ def get_models() -> List[str]:
44
+ """Utility to return all model IDs available in the MODEL_REGISTRY."""
45
+ return list(MODEL_REGISTRY.keys())
46
+
47
  # Initialize the HuggingFace endpoint using the InferenceClient
48
+ def get_llm(model_id: str = DEFAULT_MODEL_ID):
49
  """
50
  Automatically detects the model type and returns the correct LangChain interface.
51
  Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
 
55
  if not hf_token:
56
  raise ValueError("HuggingFace API token not found!")
57
 
58
+ # Determine the model type based on the MODEL_REGISTRY
59
+ active_model = model_id if model_id else DEFAULT_MODEL_ID
60
+
61
+ if active_model not in MODEL_REGISTRY:
62
+ print(f"Warning: Model '{active_model}' not found in MODEL_REGISTRY. Defaulting to 'chat' type.")
63
+
64
+ model_type = MODEL_REGISTRY.get(active_model, "chat")
65
+ print(f"Initializing HuggingFace InferenceClient with model: {active_model}")
66
 
67
  if model_type == "chat":
68
  client = InferenceClient(api_key=hf_token)
69
+ return HFChatWrapper(client=client, model_id=active_model)
70
  elif model_type == "text":
71
  # Route to standard Text Generation API
72
  return HuggingFaceEndpoint(
73
+ repo_id=active_model,
74
  task="text-generation",
75
  max_new_tokens=512,
76
  temperature=0.0,
 
83
 
84
  # Initialize the HuggingFace InferenceClient
85
  #client = InferenceClient(api_key=hf_token)
86
+ #llm = HFChatWrapper(client=client, model_id=active_model)
87
 
88
  #return llm
89
 
backend/src/nl2sql/sql_agent.py CHANGED
@@ -84,7 +84,7 @@ def clean_sql(raw_sql: str) -> str:
84
  return cleaned.strip()
85
 
86
  # Function to handle NL2SQL conversion
87
- def nl2sql_agent(user_question: str, max_retries: int = 3) -> dict:
88
  """
89
  Complete flow execution with Auto-correction:
90
  Get Schema context -> Generate SQL query -> Execute SQL query -> If Error, Refine & Retry ->Return results
@@ -94,7 +94,7 @@ def nl2sql_agent(user_question: str, max_retries: int = 3) -> dict:
94
  schema = get_schema_context(question = user_question)
95
 
96
  # Generate SQL query using the schema context and user question
97
- llm = get_llm()
98
 
99
  # LangChain Pipeline: Pipe prompt into LLM
100
  chain = prompt_template | llm
@@ -107,7 +107,7 @@ def nl2sql_agent(user_question: str, max_retries: int = 3) -> dict:
107
  # Auto-correction Loop
108
  for attempt in range(1, max_retries + 1):
109
  if attempt == 1:
110
- print("Generating initial SQL query...")
111
  raw_response = chain.invoke({
112
  "schema": schema,
113
  "question": user_question
@@ -122,50 +122,50 @@ def nl2sql_agent(user_question: str, max_retries: int = 3) -> dict:
122
  "error_message": error_message
123
  })
124
 
125
- # Parse & clean the generated SQL query
126
- generated_sql = clean_sql(raw_response)
127
- current_sql = generated_sql
128
- print(f"Generated SQL: \n{generated_sql}")
129
-
130
- # Execute the generated SQL query and fetch results
131
- connection = get_db_connection()
132
- if not connection:
133
- return {
134
- "query": generated_sql,
135
- "error": "Could not establish database connection",
136
- "status": "failed"
137
- }
 
 
 
 
 
138
 
139
- try:
140
- cursor = connection.cursor()
141
- cursor.execute(generated_sql)
142
- results = cursor.fetchall()
143
-
144
- if attempt > 1:
145
- print(f"SQL query executed successfully after {attempt} attempts.")
146
-
147
- # Generate natural language response based on the results
148
- print("Generating natural language response based on query results...")
149
- nl_response = nl_chain.invoke({
150
- "question": user_question,
151
- "results": str(results)
152
- })
153
-
154
- return {
155
- "query": generated_sql,
156
- "results": results,
157
- "nl_response": nl_response,
158
- "status": "success",
159
- "attempts": attempt
160
- }
161
- except Exception as e:
162
- error_message = str(e)
163
- print(f"Error executing SQL: {error_message}")
164
-
165
- if attempt == max_retries:
166
- print("Max retries reached. Returning error.")
167
- finally:
168
- connection.close()
169
 
170
  return {
171
  "query": current_sql,
 
84
  return cleaned.strip()
85
 
86
  # Function to handle NL2SQL conversion
87
+ def nl2sql_agent(user_question: str, max_retries: int = 3, model_id: str = None) -> dict:
88
  """
89
  Complete flow execution with Auto-correction:
90
  Get Schema context -> Generate SQL query -> Execute SQL query -> If Error, Refine & Retry ->Return results
 
94
  schema = get_schema_context(question = user_question)
95
 
96
  # Generate SQL query using the schema context and user question
97
+ llm = get_llm(model_id=model_id)
98
 
99
  # LangChain Pipeline: Pipe prompt into LLM
100
  chain = prompt_template | llm
 
107
  # Auto-correction Loop
108
  for attempt in range(1, max_retries + 1):
109
  if attempt == 1:
110
+ print(f"Generating initial SQL query using {model_id or 'default model'}...")
111
  raw_response = chain.invoke({
112
  "schema": schema,
113
  "question": user_question
 
122
  "error_message": error_message
123
  })
124
 
125
+ # Parse & clean the generated SQL query
126
+ generated_sql = clean_sql(raw_response)
127
+ current_sql = generated_sql
128
+ print(f"Generated SQL: \n{generated_sql}")
129
+
130
+ # Execute the generated SQL query and fetch results
131
+ connection = get_db_connection()
132
+ if not connection:
133
+ return {
134
+ "query": generated_sql,
135
+ "error": "Could not establish database connection",
136
+ "status": "failed"
137
+ }
138
+
139
+ try:
140
+ cursor = connection.cursor()
141
+ cursor.execute(generated_sql)
142
+ results = cursor.fetchall()
143
 
144
+ if attempt > 1:
145
+ print(f"SQL query executed successfully after {attempt} attempts.")
146
+
147
+ # Generate natural language response based on the results
148
+ print("Generating natural language response based on query results...")
149
+ nl_response = nl_chain.invoke({
150
+ "question": user_question,
151
+ "results": str(results)
152
+ })
153
+
154
+ return {
155
+ "query": generated_sql,
156
+ "results": results,
157
+ "nl_response": nl_response,
158
+ "status": "success",
159
+ "attempts": attempt
160
+ }
161
+ except Exception as e:
162
+ error_message = str(e)
163
+ print(f"Error executing SQL: {error_message}")
164
+
165
+ if attempt == max_retries:
166
+ print("Max retries reached. Returning error.")
167
+ finally:
168
+ connection.close()
 
 
 
 
 
169
 
170
  return {
171
  "query": current_sql,