kirankotha commited on
Commit
3ff1e33
·
verified ·
1 Parent(s): 3c09a33

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, torch, gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ model_id = "kirankotha/mistral7b-sql-model"
5
+
6
+ tok = AutoTokenizer.from_pretrained(model_id)
7
+ mdl = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
8
+
9
+ if tok.pad_token is None:
10
+ tok.pad_token = tok.eos_token
11
+
12
+ def gen_sql(question, context):
13
+ prompt = f"You are a text-to-SQL model.\n### Input:\n{question}\n### Context:\n{context}\n### Response:\n"
14
+ ids = tok(prompt, return_tensors="pt").to(mdl.device)
15
+ with torch.no_grad():
16
+ out = mdl.generate(
17
+ **ids,
18
+ max_new_tokens=128,
19
+ do_sample=False,
20
+ pad_token_id=tok.pad_token_id
21
+ )
22
+ text = tok.decode(out[0], skip_special_tokens=True)
23
+ m = re.findall(r"```(.*?)```", text, flags=re.DOTALL)
24
+ return m[-1].strip() if m else text.strip()
25
+
26
+ demo = gr.Interface(
27
+ fn=gen_sql,
28
+ inputs=["text","text"],
29
+ outputs="text",
30
+ title="Mistral-7B Text-to-SQL Demo",
31
+ description="Fine-tuned Mistral-7B model for text-to-SQL generation.",
32
+ examples=[["Which product has the highest price?",
33
+ "CREATE TABLE products (id INTEGER, name TEXT, price REAL)"]]
34
+ )
35
+
36
+ demo.launch()