johneze committed on
Commit
107f585
·
verified ·
1 Parent(s): 09b455a

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +107 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -14,13 +14,114 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
 
15
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
16
 
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
18
- model = AutoModelForCausalLM.from_pretrained(
19
- MODEL_ID,
20
- torch_dtype=torch.bfloat16,
21
- device_map="auto",
22
- )
23
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def extract_sql(text: str) -> str:
 
14
 
15
  MODEL_ID = "johneze/Llama-3.1-8B-Instruct-chichewa-text2sql"
16
 
17
+ # Tokenizer is tiny — safe to load at startup without a GPU
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
+
20
+ # Model is loaded lazily on the FIRST call to generate_sql, where the GPU
21
+ # context (@spaces.GPU) is already active and CUDA is available.
22
+ _pipe = None
23
+
24
+
25
def extract_sql(text: str) -> str:
    """Pull a single SELECT statement out of raw model output.

    The text is searched case-insensitively for the first standalone
    ``select`` keyword followed by whitespace. Everything from that keyword
    onward is kept, then cut at the first ``;`` and, after that, at the first
    newline, so only the first line of the first statement survives. The
    result is stripped and terminated with exactly one ``;``.

    Args:
        text: Raw generated text that may contain a SQL query.

    Returns:
        The extracted statement ending in ``;``, or ``text.strip()``
        unchanged when no SELECT keyword is found.
    """
    # \b keeps words that merely end in "select" (e.g. "deselect") from being
    # mistaken for the start of a query.
    match = re.search(r"(?is)\bselect\s.+", text)
    if match is None:
        return text.strip()

    statement = match.group(0)
    # Cut at the statement terminator first, then at the end of the line.
    for separator in (";", "\n"):
        statement = statement.partition(separator)[0]
    return statement.strip() + ";"
34
+
35
+
36
@spaces.GPU(duration=120)
def generate_sql(question: str, language: str = "ny") -> str:
    """Turn a natural-language question into a SQL SELECT statement.

    Args:
        question: The question text, written in Chichewa or English.
        language: "ny" labels the question as Chichewa in the prompt; any
            other value labels it as English.

    Returns:
        The query pulled out of the model's completion by ``extract_sql``.
    """
    global _pipe

    # Build the text-generation pipeline once, on the first call, and cache it
    # in the module-level ``_pipe`` so later calls reuse the loaded model.
    if _pipe is None:
        # NOTE(review): the previous comment described this as loading "the
        # 4-bit quantized model", but no quantization config is passed here;
        # confirm whether the checkpoint itself carries one.
        _pipe = pipeline(
            "text-generation",
            model=AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                dtype=torch.bfloat16,
                device_map="auto",
            ),
            tokenizer=tokenizer,
        )

    if language == "ny":
        lang_name = "Chichewa"
    else:
        lang_name = "English"

    system_prompt = (
        "You are an expert Text-to-SQL model for a SQLite database "
        "with the following tables: production, population, food_insecurity, "
        "commodity_prices, mse_daily. "
        "Given a natural language question, generate ONE valid SQL SELECT query. "
        "Return ONLY the SQL query, no explanation."
    )
    user_prompt = f"Language: {lang_name}\nQuestion: {question}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Render the chat as a plain string ending with the generation prompt.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    results = _pipe(
        prompt,
        max_new_tokens=128,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    full_text = results[0]["generated_text"]

    # Drop the echoed prompt when the output starts with it, so only the
    # newly generated text is parsed for SQL.
    if full_text.startswith(prompt):
        completion = full_text[len(prompt):]
    else:
        completion = full_text
    return extract_sql(completion)
85
+
86
+
87
+ # ── Gradio UI ──────────────────────────────────────────────────────────────
88
with gr.Blocks(title="Chichewa Text-to-SQL") as demo:
    # Page heading plus a one-line usage hint.
    gr.Markdown("# Chichewa Text-to-SQL\nEnter a question in Chichewa or English to generate SQL.")

    # The question text and its language code share one row; their order
    # matches the (question, language) parameters of generate_sql.
    with gr.Row():
        question_box = gr.Textbox(label="Question", placeholder="Ndi boma liti komwe anakolola chimanga chambiri?", lines=3)
        language_box = gr.Radio(["ny", "en"], value="ny", label="Language")

    submit_btn = gr.Button("Generate SQL", variant="primary")
    sql_output = gr.Code(label="Generated SQL", language="sql")

    # Clicking the button runs generate_sql on the two inputs and shows the
    # returned query in the code box.
    submit_btn.click(fn=generate_sql, inputs=[question_box, language_box], outputs=sql_output)

    # Sample questions in both languages that fill the inputs above.
    gr.Examples(
        examples=[
            ["Ndi boma liti komwe anakolola chimanga chambiri?", "ny"],
            ["Which district produced the most Maize?", "en"],
            ["Ndi anthu angati ku Lilongwe?", "ny"],
            ["What is the food insecurity level in Nsanje?", "en"],
        ],
        inputs=[question_box, language_box],
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
125
 
126
 
127
  def extract_sql(text: str) -> str:
requirements.txt CHANGED
@@ -4,3 +4,4 @@ torch>=2.4.0
4
  accelerate>=0.34.0
5
  safetensors>=0.4.0
6
  spaces>=0.30.0
 
 
4
  accelerate>=0.34.0
5
  safetensors>=0.4.0
6
  spaces>=0.30.0
7
+ bitsandbytes>=0.46.1