johneze committed on
Commit
aa02a35
Β·
verified Β·
1 Parent(s): e06e3a1

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +112 -42
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,31 +1,93 @@
1
  """
2
  Chichewa Text-to-SQL β€” HuggingFace Space
3
- Loads johneze/Llama-3.1-8B-Instruct-chichewa-text2sql and exposes a
4
- Gradio API endpoint that the Streamlit app (or anyone) can call.
5
- Uses ZeroGPU for free GPU access on HF Spaces.
6
  """
7
  from __future__ import annotations
8
 
 
9
  import re
 
 
 
 
10
  import spaces
11
  import gradio as gr
12
  import torch
 
13
  from huggingface_hub import snapshot_download
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
 
16
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
17
 
18
- # Pre-download all model files to disk at startup (no GPU required).
19
- # When @spaces.GPU activates, from_pretrained reads from the local cache
20
- # instead of downloading β€” slashing first-call latency significantly.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  print("Downloading model weights to cache …")
22
  _model_cache = snapshot_download(repo_id=MODEL_ID)
23
  print(f"Model cached at: {_model_cache}")
24
 
25
- # Tokenizer is tiny β€” safe to load at startup without a GPU
26
  tokenizer = AutoTokenizer.from_pretrained(_model_cache)
27
-
28
- # Model is loaded lazily on the FIRST call inside @spaces.GPU where CUDA is live.
29
  _pipe = None
30
 
31
 
@@ -41,15 +103,12 @@ def extract_sql(text: str) -> str:
41
 
42
 
43
  @spaces.GPU(duration=300)
44
- def generate_sql(question: str, language: str = "ny") -> str:
45
  """
46
- Generate SQL from a Chichewa or English question.
47
- language: 'ny' for Chichewa, 'en' for English.
48
- Returns a SQL SELECT statement.
49
  """
50
  global _pipe
51
  if _pipe is None:
52
- # Weights already on disk β€” this only loads into VRAM (~30-60s)
53
  model = AutoModelForCausalLM.from_pretrained(
54
  _model_cache,
55
  dtype=torch.bfloat16,
@@ -58,42 +117,54 @@ def generate_sql(question: str, language: str = "ny") -> str:
58
  _pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
59
 
60
  lang_name = "Chichewa" if language == "ny" else "English"
61
-
62
  messages = [
63
  {
64
  "role": "system",
65
  "content": (
66
  "You are an expert Text-to-SQL model for a SQLite database "
67
- "with the following tables: production, population, food_insecurity, "
68
  "commodity_prices, mse_daily. "
69
- "Given a natural language question, generate ONE valid SQL SELECT query. "
70
- "Return ONLY the SQL query, no explanation."
71
  ),
72
  },
73
- {
74
- "role": "user",
75
- "content": f"Language: {lang_name}\nQuestion: {question}",
76
- },
77
  ]
78
 
79
- prompt = tokenizer.apply_chat_template(
80
- messages, tokenize=False, add_generation_prompt=True
81
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- out = _pipe(
84
- prompt,
85
- max_new_tokens=128,
86
- do_sample=False,
87
- pad_token_id=tokenizer.eos_token_id,
88
- )[0]["generated_text"]
 
 
89
 
90
- generated = out[len(prompt):] if out.startswith(prompt) else out
91
- return extract_sql(generated)
92
 
93
 
94
  # ── Gradio UI ──────────────────────────────────────────────────────────────
95
  with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
96
- gr.Markdown("# Chichewa Text-to-SQL\nEnter a question in Chichewa or English to generate SQL.")
97
 
98
  with gr.Row():
99
  question_box = gr.Textbox(
@@ -101,19 +172,18 @@ with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
101
  placeholder="Ndi boma liti komwe anakolola chimanga chambiri?",
102
  lines=3,
103
  )
104
- language_box = gr.Radio(
105
- ["ny", "en"],
106
- value="ny",
107
- label="Language",
108
- )
109
 
110
- submit_btn = gr.Button("Generate SQL", variant="primary")
111
- sql_output = gr.Code(label="Generated SQL", language="sql")
 
 
 
112
 
113
  submit_btn.click(
114
  fn=generate_sql,
115
  inputs=[question_box, language_box],
116
- outputs=sql_output,
117
  )
118
 
119
  gr.Examples(
 
1
  """
2
  Chichewa Text-to-SQL β€” HuggingFace Space
3
+ - Generates SQL from Chichewa/English questions using the fine-tuned model
4
+ - Matches question against the training dataset (baseline retrieval)
5
+ - Executes the SQL against the bundled SQLite database and returns results
6
  """
7
  from __future__ import annotations
8
 
9
+ import json
10
  import re
11
+ import sqlite3
12
+ import difflib
13
+ from pathlib import Path
14
+
15
  import spaces
16
  import gradio as gr
17
  import torch
18
+ import pandas as pd
19
  from huggingface_hub import snapshot_download
20
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
21
 
22
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
23
 
24
# Files uploaded alongside app.py into the Space root
_HERE = Path(__file__).parent
# Retrieval corpus of question/SQL examples (JSON list of dicts).
DATA_PATH = _HERE / "data" / "all.json"
# Bundled SQLite database that generated queries are executed against.
DB_PATH = _HERE / "data" / "database" / "chichewa_text2sql.db"

# Write/DDL keywords rejected by run_query's guard — the app is read-only.
# NOTE(review): run_query checks these by plain substring, so identifiers
# like "created_at" also match — confirm that is intended.
FORBIDDEN = {"insert","update","delete","drop","alter","attach","pragma","create","replace","truncate"}
30
+
31
# ── Dataset ────────────────────────────────────────────────────────────────
# Load the retrieval corpus once at import time; find_match scans it per call.
_examples: list = []
if DATA_PATH.exists():
    with DATA_PATH.open("r", encoding="utf-8") as _f:
        _examples = json.load(_f)
    print(f"Loaded {len(_examples)} dataset examples.")
else:
    # Non-fatal: SQL generation still works, only dataset matching degrades.
    print(f"WARNING: dataset not found at {DATA_PATH}")
39
+
40
+
41
+ def _norm(t: str) -> str:
42
+ return " ".join(t.lower().strip().split())
43
+
44
+
45
def find_match(question: str, language: str):
    """Look up *question* in the loaded dataset examples.

    Tries an exact normalized match first, then a fuzzy one via difflib.
    Returns (example | None, score, mode) with mode one of
    "exact", "fuzzy", or "none".
    """
    field = "question_ny" if language == "ny" else "question_en"
    target = _norm(question)

    # Exact match on the normalized question text.
    for entry in _examples:
        if _norm(entry.get(field, "")) == target:
            return entry, 1.0, "exact"

    # Fuzzy fallback: closest dataset question above the 0.5 cutoff.
    normalized = [_norm(entry.get(field, "")) for entry in _examples]
    closest = difflib.get_close_matches(target, normalized, n=1, cutoff=0.5)
    if not closest:
        return None, 0.0, "none"
    pos = normalized.index(closest[0])
    ratio = difflib.SequenceMatcher(None, target, closest[0]).ratio()
    return _examples[pos], round(ratio, 3), "fuzzy"
58
+
59
+
60
+ # ── SQL execution ──────────────────────────────────────────────────────────
61
def run_query(sql: str):
    """Validate and run a single SELECT query against the bundled SQLite DB.

    Returns (DataFrame | None, error_str | None): a (possibly empty)
    DataFrame on success, or None plus a human-readable error message.
    """
    s = sql.strip().rstrip(";")
    if not s.lower().startswith("select"):
        return None, "Only SELECT statements are allowed."
    if ";" in s:
        return None, "Multiple statements not allowed."
    # Match forbidden keywords on word boundaries so legitimate identifiers
    # such as "created_at" or "last_updated" are not rejected by substring
    # hits on "create"/"update".
    if any(re.search(rf"\b{kw}\b", s, re.IGNORECASE) for kw in FORBIDDEN):
        return None, "Forbidden keyword detected."
    if not DB_PATH.exists():
        return None, f"Database not found at {DB_PATH}"
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        # Execute the validated string `s`, not the raw input, so the checks
        # above apply to exactly what runs.
        rows = conn.execute(s).fetchall()
        if not rows:
            return pd.DataFrame(), None
        return pd.DataFrame([dict(r) for r in rows]), None
    except Exception as exc:
        # Surface SQLite errors to the caller instead of crashing the handler.
        return None, str(exc)
    finally:
        conn.close()
83
+
84
+
85
# ── Model loading ──────────────────────────────────────────────────────────
# Pre-download all weights to the local HF cache at startup (no GPU needed);
# the GPU-decorated handler then loads from disk instead of the network.
print("Downloading model weights to cache …")
_model_cache = snapshot_download(repo_id=MODEL_ID)
print(f"Model cached at: {_model_cache}")

# Tokenizer is small enough to load eagerly without a GPU.
tokenizer = AutoTokenizer.from_pretrained(_model_cache)

# The text-generation pipeline is created lazily on the first GPU call.
_pipe = None
92
 
93
 
 
103
 
104
 
105
  @spaces.GPU(duration=300)
106
+ def generate_sql(question: str, language: str = "ny"):
107
  """
108
+ Returns (sql: str, match_info: str, results: pd.DataFrame)
 
 
109
  """
110
  global _pipe
111
  if _pipe is None:
 
112
  model = AutoModelForCausalLM.from_pretrained(
113
  _model_cache,
114
  dtype=torch.bfloat16,
 
117
  _pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
118
 
119
  lang_name = "Chichewa" if language == "ny" else "English"
 
120
  messages = [
121
  {
122
  "role": "system",
123
  "content": (
124
  "You are an expert Text-to-SQL model for a SQLite database "
125
+ "with tables: production, population, food_insecurity, "
126
  "commodity_prices, mse_daily. "
127
+ "Generate ONE valid SQL SELECT query. Return ONLY the SQL, no explanation."
 
128
  ),
129
  },
130
+ {"role": "user", "content": f"Language: {lang_name}\nQuestion: {question}"},
 
 
 
131
  ]
132
 
133
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
134
+ out = _pipe(prompt, max_new_tokens=128, do_sample=False,
135
+ pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
136
+ generated = out[len(prompt):] if out.startswith(prompt) else out
137
+ sql = extract_sql(generated)
138
+
139
+ # ── Dataset match ──────────────────────────────────────────────────────
140
+ example, score, mode = find_match(question, language)
141
+ if example:
142
+ match_info = (
143
+ f"**Match:** {mode} (score: {score})\n\n"
144
+ f"**ny:** {example.get('question_ny', '')}\n\n"
145
+ f"**en:** {example.get('question_en', '')}\n\n"
146
+ f"**Dataset SQL:** `{example.get('sql_statement', '')}`\n\n"
147
+ f"**Table:** {example.get('table', '')}  |  "
148
+ f"**Difficulty:** {example.get('difficulty_level', '')}"
149
+ )
150
+ else:
151
+ match_info = "_No close match found in the dataset._"
152
 
153
+ # ── Execute SQL ────────────────────────────────────────────────────────
154
+ df, err = run_query(sql)
155
+ if err:
156
+ results = pd.DataFrame([{"error": err}])
157
+ elif df is not None and not df.empty:
158
+ results = df
159
+ else:
160
+ results = pd.DataFrame([{"info": "Query returned no rows."}])
161
 
162
+ return sql, match_info, results
 
163
 
164
 
165
  # ── Gradio UI ──────────────────────────────────────────────────────────────
166
  with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
167
+ gr.Markdown("# 🌍 Chichewa Text-to-SQL\nEnter a question in Chichewa or English to generate SQL, match it against the dataset, and run it on the database.")
168
 
169
  with gr.Row():
170
  question_box = gr.Textbox(
 
172
  placeholder="Ndi boma liti komwe anakolola chimanga chambiri?",
173
  lines=3,
174
  )
175
+ language_box = gr.Radio(["ny", "en"], value="ny", label="Language")
 
 
 
 
176
 
177
+ submit_btn = gr.Button("Generate SQL & Run", variant="primary")
178
+
179
+ sql_output = gr.Code(label="Generated SQL", language="sql")
180
+ match_output = gr.Markdown(label="Dataset Match")
181
+ result_output = gr.Dataframe(label="Query Results", wrap=True)
182
 
183
  submit_btn.click(
184
  fn=generate_sql,
185
  inputs=[question_box, language_box],
186
+ outputs=[sql_output, match_output, result_output],
187
  )
188
 
189
  gr.Examples(
requirements.txt CHANGED
@@ -4,4 +4,5 @@ torch>=2.4.0
4
  accelerate>=0.34.0
5
  safetensors>=0.4.0
6
  spaces>=0.30.0
7
- bitsandbytes>=0.46.1
 
 
4
  accelerate>=0.34.0
5
  safetensors>=0.4.0
6
  spaces>=0.30.0
7
+ bitsandbytes>=0.46.1
8
+ pandas>=2.0.0