Chaitanya182004 commited on
Commit
dabf498
·
verified ·
1 Parent(s): d469104

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+
5
+ MODEL_NAME = "Chaitanya182004/nl2sql-model"
6
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
7
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
8
+
9
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
+ model = model.to(device)
11
+
12
+ def generate_sql(question, context):
13
+ input_text = f"{question} | {context}"
14
+ inputs = tokenizer(input_text, return_tensors='pt',
15
+ max_length=512, truncation=True).to(device)
16
+ outputs = model.generate(**inputs, max_new_tokens=128,
17
+ num_beams=4, early_stopping=True)
18
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
+ return sql
20
+
21
+ demo = gr.Interface(
22
+ fn=generate_sql,
23
+ inputs=[
24
+ gr.Textbox(label="Question"),
25
+ gr.Textbox(label="Context")
26
+ ],
27
+ outputs=gr.Textbox(label="SQL"),
28
+ title="NL2SQL API"
29
+ )
30
+
31
+ demo.launch()