dvwn commited on
Commit
5ac1ec2
·
1 Parent(s): ccb82de

Update HF Model Engine > hf_engine.py

Browse files

1. Implement custom wrapper for LangChain compatibility for later agent implementation.
2. Use InferenceClient to initialize the model with Inference API Provider
3. Tested with local code block
- Successfully loaded

src/nl2sql/__pycache__/hf_engine.cpython-313.pyc CHANGED
Binary files a/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc and b/src/nl2sql/__pycache__/hf_engine.cpython-313.pyc differ
 
src/nl2sql/hf_engine.py CHANGED
@@ -1,98 +1,63 @@
1
- #"""Hugging Face inference helpers for SQL generation."""
2
-
3
  import os
4
- import re
5
-
6
- from dotenv import load_dotenv
7
  from huggingface_hub import InferenceClient
8
-
9
-
10
- load_dotenv()
11
- hf_token = os.getenv("HF_TOKEN")
12
- if not hf_token:
13
- raise ValueError("Token Not Found!")
14
-
15
- client = InferenceClient(api_key=hf_token)
16
- MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
17
-
18
-
19
- def _build_messages(question: str, schema_context: str):
20
- system_content = (
21
- "You are an expert SQLite assistant that converts natural language into one "
22
- "executable SQLite query.\n"
23
- "Rules:\n"
24
- "1. Use only tables, columns, and join paths present in the provided schema.\n"
25
- "2. Generate valid SQLite syntax only.\n"
26
- "3. Prefer exact column names from the schema, never invent columns.\n"
27
- "4. Use explicit JOIN conditions when multiple tables are required.\n"
28
- "5. Use GROUP BY for aggregates by entity, HAVING for aggregate filters, "
29
- "ORDER BY for ranking, and LIMIT for top-N requests.\n"
30
- "6. Return SQL only. No markdown, explanations, comments, or chain-of-thought.\n"
31
- "7. If a join is needed, use short aliases that remain readable.\n"
32
- "8. Produce a single SELECT statement."
33
- )
34
-
35
- user_content = f"""Database schema:
36
- {schema_context}
37
-
38
- Question:
39
- {question}
40
-
41
- Write the SQLite query that answers the question. Return only the SQL query."""
42
-
43
- return [
44
- {"role": "system", "content": system_content},
45
- {"role": "user", "content": user_content},
46
- ]
47
-
48
-
49
- def _extract_sql(raw_response: str) -> str:
50
- text = raw_response.strip()
51
- fenced_match = re.search(r"```(?:sql)?\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL)
52
- if fenced_match:
53
- text = fenced_match.group(1).strip()
54
-
55
- statement_match = re.search(
56
- r"(?is)\b(WITH|SELECT)\b.*?(;|$)",
57
- text,
58
- )
59
- if statement_match:
60
- text = statement_match.group(0).strip()
61
-
62
- lines = [
63
- line.strip()
64
- for line in text.splitlines()
65
- if line.strip() and not line.strip().startswith(("--", "#"))
66
- ]
67
- sql = " ".join(lines).strip()
68
- if sql and not sql.endswith(";"):
69
- sql = f"{sql};"
70
- return sql
71
-
72
-
73
- def generate_sql(question, ddl):
74
- try:
75
- completion = client.chat.completions.create(
76
- model=MODEL_ID,
77
- messages=_build_messages(question, ddl),
78
- max_tokens=220,
79
- temperature=0,
80
  )
81
- raw_response = completion.choices[0].message.content or ""
82
- sql = _extract_sql(raw_response)
83
- return sql or raw_response.strip()
84
- except Exception as error:
85
- return f"Error: {error}"
86
-
87
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  if __name__ == "__main__":
89
- my_ddl = "CREATE TABLE tracks (id INTEGER PRIMARY KEY, title TEXT, genre TEXT);"
90
- my_question = "How many tracks are there in each genre?"
91
 
92
- print("Generating SQL query via Featherless AI...")
93
  try:
94
- result = generate_sql(my_question, my_ddl)
95
- print("-" * 20)
96
- print(result)
97
- except Exception as error:
98
- print(f"An error occurred: {error}")
 
 
1
+ # Path: src/nl2sql/hf_engine.py
2
+ # This module defines the HuggingFace-based engine for generating SQL queries from natural language questions.
3
  import os
 
 
 
4
  from huggingface_hub import InferenceClient
5
+ from langchain_core.language_models.llms import LLM
6
+ from typing import Any, List, Optional
7
+
8
+ # Default Model
9
+ DEFAULT_MODEL_ID = "defog/llama-3-sqlcoder-8b:featherless-ai"
10
+
11
+ # Custom LangChain wrapper for HuggingFace Inference API
12
+ class HFChatWrapper(LLM):
13
+ """
14
+ Custom LLM wrapper for HuggingFace Inference API to maintain compatibility with LangChain's LLM interface.
15
+ """
16
+ client: Any
17
+ model_id: str
18
+
19
+ def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
20
+ completion = self.client.chat.completions.create(
21
+ model = self.model_id,
22
+ messages = [
23
+ {"role": "user", "content": prompt}
24
+ ],
25
+ temperature = 0.0,
26
+ max_tokens = 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
28
+ return completion.choices[0].message.content
29
+
30
+ @property
31
+ def _llm_type(self) -> str:
32
+ return "huggingface_inference_client"
33
+
34
+ # Initialize the HuggingFace endpoint using the InferenceClient
35
+ def get_llm(model_id: str = DEFAULT_MODEL_ID):
36
+ """
37
+ Initializes the HuggingFace InferenceClient and returns an LLM instance for generating SQL queries.
38
+ """
39
+ # Load HuggingFace API token from environment variable
40
+ hf_token = os.getenv("HF_TOKEN")
41
+ if not hf_token:
42
+ raise ValueError("HuggingFace API token not found!")
43
+
44
+ print(f"Initializing HuggingFace InferenceClient with model: {model_id}")
45
+
46
+ # Initialize the HuggingFace InferenceClient
47
+ client = InferenceClient(api_key=hf_token)
48
+ llm = HFChatWrapper(client=client, model_id=model_id)
49
+
50
+ return llm
51
+
52
+ # Local Test block
53
  if __name__ == "__main__":
54
+ from dotenv import load_dotenv
55
+ load_dotenv()
56
 
 
57
  try:
58
+ test_llm = get_llm()
59
+ print("Model loaded successfully! Running a quick ping...")
60
+ response = test_llm.invoke("Write a single SQL statement to count all rows in a table named 'Employee'.")
61
+ print(f"\nResponse:\n{response}")
62
+ except Exception as e:
63
+ print(f"Error during LLM initialization or invocation: {e}")