| --- |
| language: |
| - EN |
| datasets: |
| - wikisql |
| widget: |
| - example-1: "English to SQL: Show me the average age of wines in Italy by provinces" |
| - example-2: "English to SQL: What is the current series where the new series began in June 2011?" |
| --- |
| #import transformers |
|
|
| from transformers import ( |
| T5ForConditionalGeneration, |
| T5Tokenizer, |
| ) |
| |
# Load the tokenizer and the seq2seq model from the same pretrained checkpoint.
tokenizer = T5Tokenizer.from_pretrained('dsivakumar/text2sql')
model = T5ForConditionalGeneration.from_pretrained('dsivakumar/text2sql')
|
|
| #predict function |
|
|
def get_sql(query, tokenizer, model):
    """Translate an English question into SQL text with a seq2seq model.

    Args:
        query: Natural-language question as a string.
        tokenizer: Callable tokenizer (e.g. ``T5Tokenizer``) whose result
            provides ``input_ids`` and ``attention_mask`` and which offers
            a ``decode`` method.
        model: Model exposing ``generate`` (e.g.
            ``T5ForConditionalGeneration``).

    Returns:
        A list containing one decoded string per generated sequence.
    """
    # Prepend the task prefix, then collapse any runs of whitespace.
    prompt = " ".join(("English to SQL: " + query).split())

    # Call the tokenizer directly (the supported replacement for
    # ``batch_encode_plus``) and request padding only through ``padding``;
    # the deprecated ``pad_to_max_length`` flag duplicated that setting.
    encoded = tokenizer(
        [prompt],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # With return_tensors="pt" the ids and mask are already integer (long)
    # tensors, so they are passed through as-is. No explicit ``torch`` dtype
    # cast is needed, which also keeps this function free of a ``torch``
    # name that the file never imports.
    generated_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    return [
        tokenizer.decode(
            ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for ids in generated_ids
    ]
| |
# Example run: translate one sample question and print the generated SQL.
query = "Show me the average age of of wines in Italy by provinces"
sql = get_sql(query=query, tokenizer=tokenizer, model=model)
print(sql)
| |
| |