Chaitanya182004 commited on
Commit
9a6bae9
·
verified ·
1 Parent(s): 0a90fb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -20
app.py CHANGED
@@ -2,30 +2,39 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
- MODEL_NAME = "Chaitanya182004/nl2sql-model"
6
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
7
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
8
 
9
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
- model = model.to(device)
 
 
 
 
11
 
12
- def generate_sql(question, context):
13
  input_text = f"{question} | {context}"
14
- inputs = tokenizer(input_text, return_tensors='pt',
15
- max_length=512, truncation=True).to(device)
16
- outputs = model.generate(**inputs, max_new_tokens=128,
17
- num_beams=4, early_stopping=True)
 
 
 
 
 
 
 
 
18
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
  return sql
20
 
21
- demo = gr.Interface(
22
- fn=generate_sql,
23
- inputs=[
24
- gr.Textbox(label="Question"),
25
- gr.Textbox(label="Context")
26
- ],
27
- outputs=gr.Textbox(label="SQL"),
28
- title="NL2SQL API"
29
- )
30
 
31
- demo.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
+ MODEL_NAME = "gaussalgo/T5-LM-Large-text2sql-spider"
 
 
6
 
7
+ print("Loading model...")
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+ model = model.to(device)
12
+ print(f"Model ready on {device}")
13
 
14
+ def generate_sql(question: str, context: str) -> str:
15
  input_text = f"{question} | {context}"
16
+ inputs = tokenizer(
17
+ input_text,
18
+ return_tensors='pt',
19
+ max_length=512,
20
+ truncation=True
21
+ ).to(device)
22
+ outputs = model.generate(
23
+ **inputs,
24
+ max_new_tokens=128,
25
+ num_beams=4,
26
+ early_stopping=True
27
+ )
28
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
  return sql
30
 
31
+ with gr.Blocks() as demo:
32
+ gr.Markdown("# NL2SQL API")
33
+ with gr.Row():
34
+ question = gr.Textbox(label="Question")
35
+ context = gr.Textbox(label="Context")
36
+ output = gr.Textbox(label="SQL")
37
+ btn = gr.Button("Submit")
38
+ btn.click(fn=generate_sql, inputs=[question, context], outputs=output)
 
39
 
40
+ demo.launch(server_name="0.0.0.0", server_port=7860)