johneze committed on
Commit
8cf79dd
·
verified ·
1 Parent(s): 613d990

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +33 -123
app.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
  Chichewa Text-to-SQL — HuggingFace Space
3
  - Generates SQL from Chichewa/English questions using the fine-tuned model
4
- - Matches question against the training dataset (baseline retrieval)
5
  - Executes the SQL against the bundled SQLite database and returns results
6
  """
7
  from __future__ import annotations
@@ -21,12 +21,14 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
21
 
22
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
23
 
24
- # Files uploaded alongside app.py into the Space root
25
- _HERE = Path(__file__).parent
26
  DATA_PATH = _HERE / "data" / "all.json"
27
  DB_PATH = _HERE / "data" / "database" / "chichewa_text2sql.db"
28
 
29
- FORBIDDEN = {"insert","update","delete","drop","alter","attach","pragma","create","replace","truncate"}
 
 
 
30
 
31
  # ── Dataset ────────────────────────────────────────────────────────────────
32
  _examples: list = []
@@ -57,9 +59,20 @@ def find_match(question: str, language: str):
57
  return None, 0.0, "none"
58
 
59
 
60
- # ── SQL execution ──────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
61
  def run_query(sql: str):
62
- """Validate and run a SELECT query. Returns (DataFrame | None, error_str | None)."""
63
  s = sql.strip().rstrip(";")
64
  if not s.lower().startswith("select"):
65
  return None, "Only SELECT statements are allowed."
@@ -82,8 +95,8 @@ def run_query(sql: str):
82
  conn.close()
83
 
84
 
85
- # ── Model loading ──────────────────────────────────────────────────────────
86
- print("Downloading model weights to cache …")
87
  _model_cache = snapshot_download(repo_id=MODEL_ID)
88
  print(f"Model cached at: {_model_cache}")
89
 
@@ -91,22 +104,10 @@ tokenizer = AutoTokenizer.from_pretrained(_model_cache)
91
  _pipe = None
92
 
93
 
94
- def extract_sql(text: str) -> str:
95
- match = re.search(r"(?is)select\s.+", text)
96
- if not match:
97
- return text.strip()
98
- sql = match.group(0)
99
- for sep in [";", "\n"]:
100
- if sep in sql:
101
- sql = sql.split(sep)[0]
102
- return sql.strip() + ";"
103
-
104
-
105
  @spaces.GPU(duration=300)
106
  def generate_sql(question: str, language: str = "ny"):
107
- """
108
- Returns (sql: str, match_info: str, results: pd.DataFrame)
109
- """
110
  global _pipe
111
  if _pipe is None:
112
  model = AutoModelForCausalLM.from_pretrained(
@@ -129,14 +130,13 @@ def generate_sql(question: str, language: str = "ny"):
129
  },
130
  {"role": "user", "content": f"Language: {lang_name}\nQuestion: {question}"},
131
  ]
132
-
133
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
134
  out = _pipe(prompt, max_new_tokens=128, do_sample=False,
135
  pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
136
  generated = out[len(prompt):] if out.startswith(prompt) else out
137
  sql = extract_sql(generated)
138
 
139
- # ── Dataset match ──────────────────────────────────────────────────────
140
  example, score, mode = find_match(question, language)
141
  if example:
142
  match_info = (
@@ -144,13 +144,13 @@ def generate_sql(question: str, language: str = "ny"):
144
  f"**ny:** {example.get('question_ny', '')}\n\n"
145
  f"**en:** {example.get('question_en', '')}\n\n"
146
  f"**Dataset SQL:** `{example.get('sql_statement', '')}`\n\n"
147
- f"**Table:** {example.get('table', '')}  |  "
148
  f"**Difficulty:** {example.get('difficulty_level', '')}"
149
  )
150
  else:
151
  match_info = "_No close match found in the dataset._"
152
 
153
- # ── Execute SQL ────────────────────────────────────────────────────────
154
  df, err = run_query(sql)
155
  if err:
156
  results = pd.DataFrame([{"error": err}])
@@ -164,7 +164,11 @@ def generate_sql(question: str, language: str = "ny"):
164
 
165
  # ── Gradio UI ──────────────────────────────────────────────────────────────
166
  with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
167
- gr.Markdown("# 🌍 Chichewa Text-to-SQL\nEnter a question in Chichewa or English to generate SQL, match it against the dataset, and run it on the database.")
 
 
 
 
168
 
169
  with gr.Row():
170
  question_box = gr.Textbox(
@@ -177,7 +181,7 @@ with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
177
  submit_btn = gr.Button("Generate SQL & Run", variant="primary")
178
 
179
  sql_output = gr.Code(label="Generated SQL", language="sql")
180
- match_output = gr.Markdown(label="Dataset Match")
181
  result_output = gr.Dataframe(label="Query Results", wrap=True)
182
 
183
  submit_btn.click(
@@ -196,98 +200,4 @@ with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
196
  inputs=[question_box, language_box],
197
  )
198
 
199
-
200
- if __name__ == "__main__":
201
- demo.launch()
202
-
203
-
204
- def extract_sql(text: str) -> str:
205
- match = re.search(r"(?is)select\s.+", text)
206
- if not match:
207
- return text.strip()
208
- sql = match.group(0)
209
- for sep in [";", "\n"]:
210
- if sep in sql:
211
- sql = sql.split(sep)[0]
212
- return sql.strip() + ";"
213
-
214
-
215
- @spaces.GPU
216
- def generate_sql(question: str, language: str = "ny") -> str:
217
- """
218
- Generate SQL from a Chichewa or English question.
219
- language: 'ny' for Chichewa, 'en' for English.
220
- Returns a SQL SELECT statement.
221
- """
222
- lang_name = "Chichewa" if language == "ny" else "English"
223
-
224
- messages = [
225
- {
226
- "role": "system",
227
- "content": (
228
- "You are an expert Text-to-SQL model for a SQLite database "
229
- "with the following tables: production, population, food_insecurity, "
230
- "commodity_prices, mse_daily. "
231
- "Given a natural language question, generate ONE valid SQL SELECT query. "
232
- "Return ONLY the SQL query, no explanation."
233
- ),
234
- },
235
- {
236
- "role": "user",
237
- "content": f"Language: {lang_name}\nQuestion: {question}",
238
- },
239
- ]
240
-
241
- prompt = tokenizer.apply_chat_template(
242
- messages, tokenize=False, add_generation_prompt=True
243
- )
244
-
245
- out = pipe(
246
- prompt,
247
- max_new_tokens=128,
248
- do_sample=False,
249
- pad_token_id=tokenizer.eos_token_id,
250
- )[0]["generated_text"]
251
-
252
- generated = out[len(prompt):] if out.startswith(prompt) else out
253
- return extract_sql(generated)
254
-
255
-
256
- # ── Gradio UI ──────────────────────────────────────────────────────────────
257
- with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
258
- gr.Markdown("# 🌍 Chichewa Text-to-SQL\nEnter a question in Chichewa or English to generate SQL.")
259
-
260
- with gr.Row():
261
- question_box = gr.Textbox(
262
- label="Question",
263
- placeholder="Ndi boma liti komwe anakolola chimanga chambiri?",
264
- lines=3,
265
- )
266
- language_box = gr.Radio(
267
- ["ny", "en"],
268
- value="ny",
269
- label="Language",
270
- )
271
-
272
- submit_btn = gr.Button("Generate SQL", variant="primary")
273
- sql_output = gr.Code(label="Generated SQL", language="sql")
274
-
275
- submit_btn.click(
276
- fn=generate_sql,
277
- inputs=[question_box, language_box],
278
- outputs=sql_output,
279
- )
280
-
281
- gr.Examples(
282
- examples=[
283
- ["Ndi boma liti komwe anakolola chimanga chambiri?", "ny"],
284
- ["Which district produced the most Maize?", "en"],
285
- ["Ndi anthu angati ku Lilongwe?", "ny"],
286
- ["What is the food insecurity level in Nsanje?", "en"],
287
- ],
288
- inputs=[question_box, language_box],
289
- )
290
-
291
-
292
- if __name__ == "__main__":
293
- demo.launch()
 
1
  """
2
  Chichewa Text-to-SQL — HuggingFace Space
3
  - Generates SQL from Chichewa/English questions using the fine-tuned model
4
+ - Matches question against the dataset (fuzzy retrieval)
5
  - Executes the SQL against the bundled SQLite database and returns results
6
  """
7
  from __future__ import annotations
 
21
 
22
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
23
 
24
+ _HERE = Path(__file__).parent
 
25
  DATA_PATH = _HERE / "data" / "all.json"
26
  DB_PATH = _HERE / "data" / "database" / "chichewa_text2sql.db"
27
 
28
+ FORBIDDEN = {
29
+ "insert", "update", "delete", "drop", "alter",
30
+ "attach", "pragma", "create", "replace", "truncate",
31
+ }
32
 
33
  # ── Dataset ────────────────────────────────────────────────────────────────
34
  _examples: list = []
 
59
  return None, 0.0, "none"
60
 
61
 
62
+ # ── SQL helpers ────────────────────────────────────────────────────────────
63
+ def extract_sql(text: str) -> str:
64
+ m = re.search(r"(?is)select\s.+", text)
65
+ if not m:
66
+ return text.strip()
67
+ sql = m.group(0)
68
+ for sep in [";", "\n"]:
69
+ if sep in sql:
70
+ sql = sql.split(sep)[0]
71
+ return sql.strip() + ";"
72
+
73
+
74
  def run_query(sql: str):
75
+ """Returns (DataFrame | None, error_str | None)."""
76
  s = sql.strip().rstrip(";")
77
  if not s.lower().startswith("select"):
78
  return None, "Only SELECT statements are allowed."
 
95
  conn.close()
96
 
97
 
98
+ # ── Model (pre-download weights at startup, load into GPU on first call) ───
99
+ print("Downloading model weights to cache ...")
100
  _model_cache = snapshot_download(repo_id=MODEL_ID)
101
  print(f"Model cached at: {_model_cache}")
102
 
 
104
  _pipe = None
105
 
106
 
107
+ # ── Main function ──────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
108
  @spaces.GPU(duration=300)
109
  def generate_sql(question: str, language: str = "ny"):
110
+ """Returns (sql, match_info_markdown, results_dataframe)."""
 
 
111
  global _pipe
112
  if _pipe is None:
113
  model = AutoModelForCausalLM.from_pretrained(
 
130
  },
131
  {"role": "user", "content": f"Language: {lang_name}\nQuestion: {question}"},
132
  ]
 
133
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
134
  out = _pipe(prompt, max_new_tokens=128, do_sample=False,
135
  pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
136
  generated = out[len(prompt):] if out.startswith(prompt) else out
137
  sql = extract_sql(generated)
138
 
139
+ # Dataset match
140
  example, score, mode = find_match(question, language)
141
  if example:
142
  match_info = (
 
144
  f"**ny:** {example.get('question_ny', '')}\n\n"
145
  f"**en:** {example.get('question_en', '')}\n\n"
146
  f"**Dataset SQL:** `{example.get('sql_statement', '')}`\n\n"
147
+ f"**Table:** {example.get('table', '')} | "
148
  f"**Difficulty:** {example.get('difficulty_level', '')}"
149
  )
150
  else:
151
  match_info = "_No close match found in the dataset._"
152
 
153
+ # Execute SQL
154
  df, err = run_query(sql)
155
  if err:
156
  results = pd.DataFrame([{"error": err}])
 
164
 
165
  # ── Gradio UI ──────────────────────────────────────────────────────────────
166
  with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
167
+ gr.Markdown(
168
+ "# Chichewa Text-to-SQL\n"
169
+ "Enter a question in **Chichewa** or **English** to generate SQL, "
170
+ "match it against the dataset, and run it on the database."
171
+ )
172
 
173
  with gr.Row():
174
  question_box = gr.Textbox(
 
181
  submit_btn = gr.Button("Generate SQL & Run", variant="primary")
182
 
183
  sql_output = gr.Code(label="Generated SQL", language="sql")
184
+ match_output = gr.Markdown()
185
  result_output = gr.Dataframe(label="Query Results", wrap=True)
186
 
187
  submit_btn.click(
 
200
  inputs=[question_box, language_box],
201
  )
202
 
203
+ demo.launch()