johneze committed on
Commit
c1bc8ec
·
verified ·
1 Parent(s): 8cf79dd

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +77 -53
app.py CHANGED
@@ -10,6 +10,7 @@ import json
10
  import re
11
  import sqlite3
12
  import difflib
 
13
  from pathlib import Path
14
 
15
  import spaces
@@ -95,7 +96,7 @@ def run_query(sql: str):
95
  conn.close()
96
 
97
 
98
- # ── Model (pre-download weights at startup, load into GPU on first call) ───
99
  print("Downloading model weights to cache ...")
100
  _model_cache = snapshot_download(repo_id=MODEL_ID)
101
  print(f"Model cached at: {_model_cache}")
@@ -107,59 +108,82 @@ _pipe = None
107
  # ── Main function ──────────────────────────────────────────────────────────
108
  @spaces.GPU(duration=300)
109
  def generate_sql(question: str, language: str = "ny"):
110
- """Returns (sql, match_info_markdown, results_dataframe)."""
111
- global _pipe
112
- if _pipe is None:
113
- model = AutoModelForCausalLM.from_pretrained(
114
- _model_cache,
115
- dtype=torch.bfloat16,
116
- device_map="auto",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
- _pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
119
-
120
- lang_name = "Chichewa" if language == "ny" else "English"
121
- messages = [
122
- {
123
- "role": "system",
124
- "content": (
125
- "You are an expert Text-to-SQL model for a SQLite database "
126
- "with tables: production, population, food_insecurity, "
127
- "commodity_prices, mse_daily. "
128
- "Generate ONE valid SQL SELECT query. Return ONLY the SQL, no explanation."
129
- ),
130
- },
131
- {"role": "user", "content": f"Language: {lang_name}\nQuestion: {question}"},
132
- ]
133
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
134
- out = _pipe(prompt, max_new_tokens=128, do_sample=False,
135
- pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
136
- generated = out[len(prompt):] if out.startswith(prompt) else out
137
- sql = extract_sql(generated)
138
-
139
- # Dataset match
140
- example, score, mode = find_match(question, language)
141
- if example:
142
- match_info = (
143
- f"**Match:** {mode} (score: {score})\n\n"
144
- f"**ny:** {example.get('question_ny', '')}\n\n"
145
- f"**en:** {example.get('question_en', '')}\n\n"
146
- f"**Dataset SQL:** `{example.get('sql_statement', '')}`\n\n"
147
- f"**Table:** {example.get('table', '')} | "
148
- f"**Difficulty:** {example.get('difficulty_level', '')}"
 
 
 
 
 
 
 
 
 
 
149
  )
150
- else:
151
- match_info = "_No close match found in the dataset._"
152
-
153
- # Execute SQL
154
- df, err = run_query(sql)
155
- if err:
156
- results = pd.DataFrame([{"error": err}])
157
- elif df is not None and not df.empty:
158
- results = df
159
- else:
160
- results = pd.DataFrame([{"info": "Query returned no rows."}])
161
-
162
- return sql, match_info, results
163
 
164
 
165
  # ── Gradio UI ──────────────────────────────────────────────────────────────
@@ -180,7 +204,7 @@ with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
180
 
181
  submit_btn = gr.Button("Generate SQL & Run", variant="primary")
182
 
183
- sql_output = gr.Code(label="Generated SQL", language="sql")
184
  match_output = gr.Markdown()
185
  result_output = gr.Dataframe(label="Query Results", wrap=True)
186
 
 
10
  import re
11
  import sqlite3
12
  import difflib
13
+ import traceback
14
  from pathlib import Path
15
 
16
  import spaces
 
96
  conn.close()
97
 
98
 
99
+ # ── Model (pre-download at startup, load into GPU lazily on first call) ────
100
  print("Downloading model weights to cache ...")
101
  _model_cache = snapshot_download(repo_id=MODEL_ID)
102
  print(f"Model cached at: {_model_cache}")
 
108
  # ── Main function ──────────────────────────────────────────────────────────
109
  @spaces.GPU(duration=300)
110
  def generate_sql(question: str, language: str = "ny"):
111
+ """Returns (sql_str, match_info_markdown, results_dataframe)."""
112
+ # Always return 3 values even on error so Gradio never shows generic "Error"
113
+ empty_df = pd.DataFrame()
114
+ try:
115
+ global _pipe
116
+ if _pipe is None:
117
+ model = AutoModelForCausalLM.from_pretrained(
118
+ _model_cache,
119
+ torch_dtype=torch.bfloat16,
120
+ device_map="cuda",
121
+ )
122
+ _pipe = pipeline(
123
+ "text-generation",
124
+ model=model,
125
+ tokenizer=tokenizer,
126
+ device_map="cuda",
127
+ )
128
+
129
+ lang_name = "Chichewa" if language == "ny" else "English"
130
+ messages = [
131
+ {
132
+ "role": "system",
133
+ "content": (
134
+ "You are an expert Text-to-SQL model for a SQLite database "
135
+ "with tables: production, population, food_insecurity, "
136
+ "commodity_prices, mse_daily. "
137
+ "Generate ONE valid SQL SELECT query. Return ONLY the SQL, no explanation."
138
+ ),
139
+ },
140
+ {"role": "user", "content": f"Language: {lang_name}\nQuestion: {question}"},
141
+ ]
142
+ prompt = tokenizer.apply_chat_template(
143
+ messages, tokenize=False, add_generation_prompt=True
144
  )
145
+ out = _pipe(
146
+ prompt,
147
+ max_new_tokens=128,
148
+ do_sample=False,
149
+ pad_token_id=tokenizer.eos_token_id,
150
+ )[0]["generated_text"]
151
+ generated = out[len(prompt):] if out.startswith(prompt) else out
152
+ sql = extract_sql(generated)
153
+
154
+ # Dataset match
155
+ example, score, mode = find_match(question, language)
156
+ if example:
157
+ match_info = (
158
+ f"**Match:** {mode} (score: {score})\n\n"
159
+ f"**ny:** {example.get('question_ny', '')}\n\n"
160
+ f"**en:** {example.get('question_en', '')}\n\n"
161
+ f"**Dataset SQL:** `{example.get('sql_statement', '')}`\n\n"
162
+ f"**Table:** {example.get('table', '')} | "
163
+ f"**Difficulty:** {example.get('difficulty_level', '')}"
164
+ )
165
+ else:
166
+ match_info = "_No close match found in the dataset._"
167
+
168
+ # Execute SQL
169
+ df, err = run_query(sql)
170
+ if err:
171
+ results = pd.DataFrame([{"error": err}])
172
+ elif df is not None and not df.empty:
173
+ results = df
174
+ else:
175
+ results = pd.DataFrame([{"info": "Query returned no rows."}])
176
+
177
+ return sql, match_info, results
178
+
179
+ except Exception:
180
+ err_msg = traceback.format_exc()
181
+ print(err_msg)
182
+ return (
183
+ f"-- ERROR --\n{err_msg}",
184
+ f"**Error during generation:**\n```\n{err_msg}\n```",
185
+ pd.DataFrame([{"error": err_msg}]),
186
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
  # ── Gradio UI ──────────────────────────────────────────────────────────────
 
204
 
205
  submit_btn = gr.Button("Generate SQL & Run", variant="primary")
206
 
207
+ sql_output = gr.Textbox(label="Generated SQL", lines=3)
208
  match_output = gr.Markdown()
209
  result_output = gr.Dataframe(label="Query Results", wrap=True)
210