| | """The program includes several functions: setting a random seed, |
| | loading data from a JSON file, batching data, and extracting answers from generated text. |
| | """ |
| |
|
| | import random |
| | import numpy as np |
| | import torch |
| | import json |
| | import re |
| | def set_random_seed(seed: int): |
| | """ |
| | Set the random seed for `random`, `numpy`, `torch`, `torch.cuda`. |
| | |
| | Parameters |
| | ------------ |
| | seed : int |
| | The default seed. |
| | |
| | """ |
| | random.seed(seed) |
| | np.random.seed(seed) |
| | torch.manual_seed(seed) |
| | if torch.cuda.is_available(): |
| | torch.cuda.manual_seed_all(seed) |
| |
|
| | def load_data(file_name: str): |
| | """ |
| | Load data with file name. |
| | |
| | Parameters |
| | ------------ |
| | file_name : str. |
| | The dataset file name. |
| | |
| | Returns |
| | ------------ |
| | inputs : list. |
| | The input texts of the dataset. |
| | outputs : list. |
| | The output texts file datasets. |
| | len : int. |
| | The length of the dataset. |
| | """ |
| | inputs = [] |
| | outputs = [] |
| | type = "" |
| | with open(file_name, encoding='utf-8') as f: |
| | json_data = json.load(f) |
| | type = json_data["type"] |
| | for line in json_data["instances"]: |
| | inputs.append(line["input"]) |
| | outputs.append(line["output"]) |
| |
|
| | print(f"load dataset {file_name} success.\n") |
| | print(f"Type : {type}, datasize : {len(outputs)}") |
| |
|
| | return inputs, outputs, len(outputs) |
| |
|
| | def batchlize(examples: list, batch_size: int, random_shuffle: bool): |
| | """ |
| | Convert examples to a dataloader. |
| | |
| | Parameters |
| | ------------ |
| | examples : list. |
| | Data list. |
| | batch_size : int. |
| | |
| | random_shuffle : bool |
| | If true, the dataloader shuffle the training data. |
| | |
| | Returns |
| | ------------ |
| | dataloader: |
| | Dataloader with batch generator. |
| | """ |
| | size = 0 |
| | dataloader = [] |
| | length = len(examples) |
| | if (random_shuffle): |
| | random.shuffle(examples) |
| | while size < length: |
| | if length - size > batch_size: |
| | dataloader.append(examples[size : size+batch_size]) |
| | size += batch_size |
| | else: |
| | dataloader.append(examples[size : size+(length-size)]) |
| | size += (length - size) |
| | return dataloader |
| |
|
| |
|
| |
|
| | def answer_extraction(response, answer_type=None): |
| |
|
| | """ |
| | Use this funtion to extract answers from generated text |
| | |
| | Parameters |
| | ------------ |
| | args : |
| | Arguments. |
| | response : str |
| | plain string response. |
| | |
| | |
| | Returns |
| | ------------ |
| | answer: |
| | Decoded answer (such as A, B, C, D, E for mutiple-choice QA). |
| | """ |
| |
|
| | |
| | temp = response |
| | if answer_type in ("gsm8k", "svamp", "asdiv", "addsub", "singleeq", "multiarith", "math"): |
| | temp = temp.replace(",", "") |
| | temp = [s for s in re.findall(r'-?\d+\.?\d*', temp)] |
| | elif answer_type in ("aqua", "csqa", "multiple_choice"): |
| | temp = re.findall(r'A|B|C|D|E', temp) |
| | elif answer_type in ("strategyqa", "coin_flip"): |
| | temp = temp.lower() |
| | temp = re.sub("\"|\'|\n|\.|\s|\:|\,"," ", temp) |
| | temp = temp.split(" ") |
| | temp = [i for i in temp if i in ("yes", "no")] |
| | elif answer_type in ("last_letters"): |
| | temp = re.sub("\"|\'|\n|\.|\s","", temp) |
| | temp = [temp] |
| | elif answer_type in ("pubmedqa", "binary_choice"): |
| | |
| | |
| | |
| | pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | mid_answer = sttr.group(0) |
| | mid_answer = mid_answer.split(":")[-1].strip() |
| | answer = mid_answer.lower() |
| | else: |
| | pattern = "(yes|Yes|YES|no|No|NO|maybe|Maybe|MAYBE)(\.|\s)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | answer = sttr.group(0)[:-1].lower() |
| | else: |
| | answer = "N/A" |
| | return answer |
| | elif answer_type == "medmcqa": |
| | |
| | |
| | |
| | pattern = "(answer|Answer|ANSWER|output|Output|OUTPUT|A): \(*(A|B|C|D|a|b|c|d)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | mid_answer = sttr.group(0) |
| | answer = mid_answer[-1].lower() |
| | else: |
| | pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | if '(' in sttr.group(0): |
| | answer = sttr.group(0)[1].lower() |
| | else: |
| | answer = sttr.group(0)[0].lower() |
| | else: |
| | answer = "N/A" |
| | return answer |
| |
|
| | elif answer_type == "usmle": |
| | |
| | |
| | |
| | pattern = "(Answer|Output|A): \(*(A|B|C|D|a|b|c|d)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | mid_answer = sttr.group(0) |
| | answer = mid_answer[-1].lower() |
| | else: |
| | pattern = "\(*(A|B|C|D|a|b|c|d)\)*(\.|\s)" |
| | sttr = re.search(pattern, temp) |
| | if sttr is not None: |
| | if '(' in sttr.group(0): |
| | answer = sttr.group(0)[1].lower() |
| | else: |
| | answer = sttr.group(0)[0].lower() |
| | else: |
| | answer = "N/A" |
| | return answer |
| | elif answer_type == "text": |
| | return response |
| | else: |
| | raise NotImplementedError(f"Unsupported answer type: {answer_type}") |
| |
|
| | if len(temp) != 0: |
| | answer = temp[-1] |
| | |
| | |
| | if answer != "": |
| | if answer[-1] == ".": |
| | answer = answer[:-1] |
| |
|
| | |
| | if answer_type in ("gsm8k", "svamp"): |
| | try: |
| | answer = str(round(float(answer))) |
| | except: |
| | answer = "" |
| | elif answer_type in ("last_letters"): |
| | try: |
| | answer = answer[-args.concat_length:] |
| | except: |
| | answer = "" |
| | else: |
| | answer = "" |
| | return answer |
| |
|