Gretel's baseline text2table model was fine-tuned from togethercomputer's RedPajama-INCITE-Instruct-3B-v1 for 100 epochs on 8x A100 80GB GPUs. Fine-tuning used ~2k training samples (text-and-table pairs) generated with OpenAI.
## Data Formatting
```python
INSTRUCTION_KEY = "### Instruction: Given the following prompt, generate a table"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."

PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{prompt_to_generate_table}
{response_key}
{table}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    prompt_to_generate_table="{PROMPT}",  # filled in later with the input prompt
    response_key=RESPONSE_KEY,
    table="{TABLE}",  # filled in later with the target table
)
```
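To illustrate how the template resolves, the sketch below fills the two remaining placeholders with a toy prompt/table pair (the `id`/`name` table is invented for illustration, not from the training data):

```python
# Rebuild the training template, then fill it with a toy example.
INSTRUCTION_KEY = "### Instruction: Given the following prompt, generate a table"
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."

template = """{intro}
{instruction_key}
{prompt}
{response_key}
{table}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    prompt="{PROMPT}",   # left as a placeholder after the first format pass
    response_key=RESPONSE_KEY,
    table="{TABLE}",     # left as a placeholder after the first format pass
)

# Second pass: substitute an actual prompt and its target table.
sample = template.format(
    PROMPT="Create a table with two columns: id and name.",
    TABLE="id,name\n1,Alice\n2,Bob",
)
print(sample)
```

The two-pass `format` works because the first pass leaves `{PROMPT}` and `{TABLE}` as literal placeholders for the second pass to fill.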
## Generation
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "togethercomputer/RedPajama-INCITE-Instruct-3B-v1", padding_side="right"
)
model = AutoModelForCausalLM.from_pretrained("gretelai/text2table").to("cuda")
model.eval()

INSTRUCTION_KEY = "### Instruction: Given the following prompt, generate a table."
RESPONSE_KEY = "### Response:"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."

# At inference time the table is left for the model to generate,
# so the template ends at the response key.
PROMPT_FOR_GENERATION_FORMAT = """{intro}
{instruction_key}
{prompt_to_generate_table}
{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    prompt_to_generate_table="{prompt_to_generate_table}",
    response_key=RESPONSE_KEY,
)

PROMPT = "Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment."

tokenizer.pad_token = tokenizer.eos_token
model_inputs = tokenizer(
    PROMPT_FOR_GENERATION_FORMAT.format(prompt_to_generate_table=PROMPT),
    return_tensors="pt",
).to("cuda")
outputs = model.generate(**model_inputs, max_length=1024)
table = tokenizer.decode(outputs[0], skip_special_tokens=False)
```
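Since the decoded text echoes the full prompt before the generated table, a small post-processing step is needed to recover just the CSV. A minimal sketch using only the standard library (the `decoded` string here is a shortened stand-in for a real model generation):

```python
import csv
import io

# Stand-in for the decoded model output: prompt echo, response marker, then CSV.
decoded = (
    "Below is an instruction that describes a task. ...\n"
    "### Response:\n"
    "patient,sex,agegrp,bp_before,bp_after\n"
    "1.0,F,45.0,183.0,124.0\n"
)

# Keep only the text after the response marker, then parse it as CSV.
csv_text = decoded.split("### Response:", 1)[1].strip()
rows = list(csv.reader(io.StringIO(csv_text)))
header, records = rows[0], rows[1:]
print(header)   # ['patient', 'sex', 'agegrp', 'bp_before', 'bp_after']
print(records)  # [['1.0', 'F', '45.0', '183.0', '124.0']]
```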
## Output
```
PROMPT = "Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment."

MODEL GENERATION ->

Below is an instruction that describes a task. Write a response that appropriately completes the request.
Instruction: Given the following prompt, generate a table. Each column should have random values.
Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment.
Response:
patient,sex,agegrp,bp_before,bp_after
1.0,F,45.0,183.0,124.0,234.0
2.0,F,60.0,183.0,124.0,183.0
3.0,F,70.0,179.0,117.0,183.0
4.0,M,30.0,141.0,136.0,161.0
5.0,M,70.0,147.0,129.0,157.0
6.0,M,40.0,140.0,136.0,156.0
7.0,M,60.0,140.0,116.0,157.0
8.0,M,70.0,144.0,131.0,161.0
9.0,M,60.0,142.0,119.0,157.0
10.0,M,70.0,147.0,132.0,167.0
11.0,M,60.0,147.0,136.0,166.0
12.0,M,70.0,150.0,132.0,172.0
13.0,M,60.0,149.0,137.0,162.0
14.0,M,70.0,156.0,124.0,157.0
15.0,M,60.0,156.0,181.0,157.0
16.0,M,70.0,156.0,131.0,158.0
```