upadhyay commited on
Commit
de4f232
·
verified ·
1 Parent(s): 25a3d75

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -0
README.md CHANGED
@@ -12,3 +12,32 @@ Hyperparameters:
12
  "max target length" = 256
13
  "model name" : "facebook/bart-large-cnn"
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  "max target length" = 256
13
  "model name" : "facebook/bart-large-cnn"
14
 
15
+
16
+
17
+ use code :
18
+
19
+
20
+ from typing import List
21
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained("upadhyay/sql")
24
+ model = AutoModelForSeq2SeqLM.from_pretrained("upadhyay/sql")
25
+
26
+ def prepare_input(question: str, table: List[str]):
27
+ table_prefix = "table:"
28
+ question_prefix = "question:"
29
+ join_table = ",".join(table)
30
+ inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
31
+ input_ids = tokenizer(inputs, max_length=700, return_tensors="pt").input_ids
32
+ return input_ids
33
+
34
+ def inference(question: str, table: List[str]) -> str:
35
+ input_data = prepare_input(question=question, table=table)
36
+ input_data = input_data.to(model.device)
37
+ outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700)
38
+ result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
39
+ return result
40
+
41
+ print(inference(question="what is salary?", table=["id", "name", "age"]))
42
+
43
+