mcps5601 committed on
Commit e77bcc6 · 1 Parent(s): e34188e

Add application files
Files changed (6)
  1. app.py +127 -0
  2. args.json +33 -0
  3. class_names.pkl +3 -0
  4. prompt_dataset.py +169 -0
  5. prompt_model_factory.py +88 -0
  6. utils.py +138 -0
app.py ADDED
@@ -0,0 +1,127 @@
+ from prompt_model_factory import BertForPromptFinetuning
+ from transformers import (
+     AutoTokenizer,
+     DataCollatorWithPadding,
+     TrainingArguments,
+     Trainer,
+     EvalPrediction,
+ )
+
+ # from prompt_tuning import compute_metrics
+ import torch
+ import pickle
+ import numpy as np
+ from prompt_dataset import InferenceDataset
+ import gradio as gr
+ from utils import load_params, get_label_words, pred_by_threshold
+
+
+ def compute_metrics(
+     threshold=None,
+     classes=None,
+     p_tuning=False,
+ ):
+     def compute_metric_threshold(eval_pred: EvalPrediction):
+         return pred_by_threshold(
+             t=threshold,
+             y_true=eval_pred.label_ids,
+             similarities=eval_pred.predictions
+             if p_tuning
+             else torch.sigmoid(torch.tensor(eval_pred.predictions)),
+             classes=classes,
+         )
+
+     return compute_metric_threshold
+
+
+ def greet(input_text):
+     prompt_FT = True
+     # Load the class-name-to-index mapping; a context manager ensures the file is closed.
+     with open("class_names.pkl", "rb") as f:
+         classes = pickle.load(f)
+     class_names = list(classes.keys())
+     id_to_class = {i: class_names[i] for i in range(len(class_names))}
+
+     # NOTE: assumes a second GPU ("cuda:1"); falls back to CPU otherwise.
+     device = (
+         torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")
+     )
+     args = load_params("args.json")
+     model_path = "IKMLab/MPTR"
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     if prompt_FT:
+         # Prompt tuning
+         label_words = get_label_words(list(classes.keys()), args.use_multi_label_words)
+
+         if args.use_multi_label_words:
+             label_word_ids = []
+             for l in label_words:
+                 one_label_ids = [tokenizer.convert_tokens_to_ids(word) for word in l]
+                 label_word_ids.append(one_label_ids)
+         else:
+             label_word_ids = (
+                 torch.tensor([tokenizer.convert_tokens_to_ids(l) for l in label_words])
+                 .long()
+                 .to(device)
+             )
+         model = BertForPromptFinetuning.from_pretrained(
+             model_path,
+             use_multi_label_words=args.use_multi_label_words,
+         )
+         model.label_word_ids = label_word_ids
+
+     result_path = "results/predict"
+
+     training_args = TrainingArguments(
+         output_dir=result_path,
+         learning_rate=args.lr,
+         per_device_train_batch_size=args.batch_size,
+         per_device_eval_batch_size=1,
+         num_train_epochs=args.num_epochs,
+         weight_decay=0.01,
+         warmup_ratio=args.warmup_ratio,
+         seed=args.seed,
+         evaluation_strategy="steps",
+         logging_steps=100,  # same as eval_steps
+         save_strategy="steps",
+         save_steps=100,
+         save_total_limit=1,
+         load_best_model_at_end=True,
+         metric_for_best_model=f"eval_{args.best_metric}",
+     )
+     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=None,
+         eval_dataset=None,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+         compute_metrics=compute_metrics(
+             threshold=args.t,
+             classes=classes,
+             p_tuning=prompt_FT,
+         ),
+     )
+
+     testset = InferenceDataset(
+         input_text,
+         tokenizer,
+         args.max_seq_len,
+         template=args.template,
+         prompt=args.prompt,
+     )
+     result = trainer.predict(testset)
+     predictions = (result.predictions[0] >= args.t) * 1
+     positive_idx = np.where(predictions == 1)[0]
+     if len(positive_idx) == 0:
+         return "No positive findings."
+
+     return [id_to_class[i] for i in positive_idx]
+
+
+ # test = "Two small 0.6-cm and 1.4-cm densely packed lipiodol puddles in S7 without identifiable viable tumor, suggestive of good response to previous TACE without viability."
+ # result = greet(test)
+
+
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ iface.launch()
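
For reference, the commented-out test in app.py spelled out as a runnable snippet. This is a sketch, not part of the commit: it would run inside app.py before iface.launch(), and it assumes the IKMLab/MPTR checkpoint, args.json, and class_names.pkl are reachable from the working directory.

# Sketch (not part of the commit): the commented-out test above, made explicit.
test = (
    "Two small 0.6-cm and 1.4-cm densely packed lipiodol puddles in S7 without "
    "identifiable viable tumor, suggestive of good response to previous TACE "
    "without viability."
)
result = greet(test)
print(result)  # a list of predicted class names, or "No positive findings."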
args.json ADDED
@@ -0,0 +1,33 @@
+ {
+     "template": "*cls**sent_0*_[PROMPT]*mask*.*sep+*",
+     "prompt": "The_report_is_related_to",
+     "num_labels": 7,
+     "report_filter": "full",
+     "max_seq_len": 512,
+     "batch_size": 2,
+     "model_name": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
+     "gpu_id": "1",
+     "cls_mode": "multi_label",
+     "k": null,
+     "t": 0.2,
+     "db_date": "20230606_new",
+     "best_metric": "loss",
+     "exp_tag": "",
+     "num_exps": 5,
+     "do_train": true,
+     "num_epochs": 30,
+     "do_predict": true,
+     "seed": 42,
+     "lr": 3e-05,
+     "warmup_ratio": 0.0,
+     "data_type": "train_32",
+     "save_conf_matrix": false,
+     "use_multi_label_words": true,
+     "allow_multi_label_tokens": false,
+     "verbalizer_name": "",
+     "enable_emboliz": false,
+     "enable_rfa": false,
+     "enable_tace": false,
+     "enable_lobectomy": false,
+     "save_checkpoints": true
+ }
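
app.py reads this file through utils.load_params, which wraps the parsed JSON in an argparse.Namespace so each key becomes an attribute. A minimal sketch (not part of the commit):

# Sketch: accessing args.json values the way app.py does.
from utils import load_params

args = load_params("args.json")
print(args.template)             # "*cls**sent_0*_[PROMPT]*mask*.*sep+*"
print(args.t, args.max_seq_len)  # 0.2 512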
class_names.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d8088c6c9d790808303e0a3e9b122c9d9103a4e2c30694f4d5e3351d2c25872
+ size 110
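
This file is a Git LFS pointer, so its contents are not visible in the diff. Judging from how app.py and utils.pred_by_threshold use it, the pickle presumably holds a dict mapping the seven class names to integer indices; that layout is an assumption here, not something the commit shows. A loading sketch (not part of the commit):

# Sketch: inspecting class_names.pkl; the dict layout is assumed, not confirmed.
import pickle

with open("class_names.pkl", "rb") as f:
    classes = pickle.load(f)
print(classes)  # e.g. {"cyst": 0, "HCC": 1, ...} (assumed layout)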
prompt_dataset.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ import pandas as pd
+
+
+ def get_prompt_length(tokenizer, prompt):
+     return len(tokenizer.encode(prompt))
+
+
+ def tokenize_multipart_input(
+     tokenizer,
+     input_text_list: list,
+     max_seq_len: int,
+     template=None,
+     prompt=None,
+ ):
+     """This function is an adaptation of the `tokenize_multipart_input` found in princeton-nlp's repository
+     at https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.
+
+     Modifications include:
+     - Extension of automatic prompt generation for multi-label classification.
+     - Removal of parameters like `first_sent_limit`, `other_sent_limit`, `gpt3`, `truncate_head`, and `support_labels`.
+     - Optimization of the code flow.
+
+     Args:
+         tokenizer: a pre-trained tokenizer from Hugging Face Transformers
+         input_text_list (list): documents ready for tokenization.
+         max_seq_len (int): max sequence length after adding the prompt along with special tokens from BERT.
+         template (str, optional): placeholder for the prompt.
+         prompt (str, optional): the prompt we use for input text.
+     """
+
+     def enc(text):
+         return tokenizer.encode(text, add_special_tokens=False)
+
+     input_ids = []
+     attention_mask = []
+     token_type_ids = []  # Only for BERT
+     mask_pos = None  # Position of the mask token
+
+     if prompt:
+         special_token_mapping = {
+             "cls": tokenizer.cls_token_id,
+             "mask": tokenizer.mask_token_id,
+             "sep": tokenizer.sep_token_id,
+             "sep+": tokenizer.sep_token_id,
+         }
+         # Get variable list in the template
+         if prompt != "auto":
+             template = template.replace("[PROMPT]", prompt)
+         template_list = template.split("*")
+         if prompt == "auto":
+             # find cls place
+             cls_pos = template_list.index("cls")
+             if template_list[cls_pos + 1] == "":
+                 # For these kinds of cases: *cls**sent_0*_Liver*mask*.*sep+*
+                 # Prompt is next to sent_0.
+                 prompt = template_list[cls_pos + 3]
+             elif template_list[cls_pos + 1] != "" and (
+                 template_list[cls_pos + 1].startswith("_")
+             ):
+                 # For these kinds of cases: *cls*_Liver*mask*.*+sent_0**sep+*
+                 # Prompt is next to cls.
+                 prompt = template_list[cls_pos + 1]
+             if prompt.startswith("_"):
+                 prompt = prompt[1:]
+         segment_id = 0
+
+         for part in template_list:
+             new_tokens = []
+             segment_plus_1_flag = False
+             if part in special_token_mapping:
+                 new_tokens.append(special_token_mapping[part])
+                 if part == "sep+":
+                     segment_plus_1_flag = True
+             elif part[:5] == "sent_" or part[:6] == "+sent_":
+                 sent_id = int(part.split("_")[1])
+                 max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)
+                 # Tokenize and truncate to max_seq_len
+                 tokens = enc(input_text_list[sent_id])[-max_len:]
+                 new_tokens += tokens
+             else:
+                 # Just natural language prompt
+                 part = part.replace("_", " ")
+                 # handle special case when T5 tokenizer might add an extra space
+                 if len(part) == 1:
+                     new_tokens.append(tokenizer.convert_tokens_to_ids(part))
+                 else:
+                     new_tokens += enc(part)
+
+             input_ids += new_tokens
+             attention_mask += [1 for i in range(len(new_tokens))]
+             token_type_ids += [segment_id for i in range(len(new_tokens))]
+
+             if segment_plus_1_flag:
+                 segment_id += 1
+
+         mask_pos = [input_ids.index(tokenizer.mask_token_id)]
+         # Make sure that the masked position is inside the max_length
+         assert mask_pos[0] < max_seq_len
+
+     else:
+         input_ids = [tokenizer.cls_token_id]
+         attention_mask = [1]
+         token_type_ids = [0]
+         max_len = max_seq_len - 2
+
+         for sent_id, input_text in enumerate(input_text_list):
+             if input_text is None:
+                 # Do not have text_b
+                 continue
+             if pd.isna(input_text):
+                 # Empty input
+                 input_text = ""
+             input_tokens = enc(input_text)[:max_len] + [tokenizer.sep_token_id]
+             input_ids += input_tokens
+             attention_mask += [1 for i in range(len(input_tokens))]
+             token_type_ids += [sent_id for i in range(len(input_tokens))]
+
+     return input_ids, attention_mask, token_type_ids, mask_pos
+
+
+ class InferenceDataset(torch.utils.data.Dataset):
+     """
+     A single-example dataset that wraps one raw report string for inference.
+     ---
+     Attributes
+         doc (str): the input report text
+         tokenizer: a pre-trained HuggingFace tokenizer
+         max_seq_len (int): maximum length for a sequence
+         template (str, optional): template for the prompt. Defaults to None.
+         prompt (str, optional): prompt for the input text. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         input_text: str,
+         tokenizer,
+         max_seq_len: int,
+         template=None,
+         prompt=None,
+     ):
+         self.doc = input_text
+         self.template = template
+         self.prompt = prompt
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+
+     def __getitem__(self, idx):
+         input_ids, attn_mask, segs, mask_pos = tokenize_multipart_input(
+             tokenizer=self.tokenizer,
+             input_text_list=[self.doc],
+             template=self.template,
+             prompt=self.prompt,
+             max_seq_len=self.max_seq_len,
+         )
+         item = {
+             "input_ids": input_ids,
+             "token_type_ids": segs,
+             "attention_mask": attn_mask,
+         }
+         if self.prompt:
+             item["mask_pos"] = mask_pos
+         return item
+
+     def __len__(self):
+         return 1
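
To make the template handling concrete: with the template and prompt from args.json, tokenize_multipart_input renders roughly "[CLS] <report tokens> the report is related to [MASK] . [SEP]" and records the [MASK] position. A sketch (not part of the commit); the report text is hypothetical, and the tokenizer name is the model_name from args.json.

# Sketch: building the single inference example the way app.py does.
from transformers import AutoTokenizer
from prompt_dataset import InferenceDataset

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
ds = InferenceDataset(
    "Liver cirrhosis with a 1.2-cm cyst in S6.",  # hypothetical report text
    tokenizer,
    max_seq_len=512,
    template="*cls**sent_0*_[PROMPT]*mask*.*sep+*",
    prompt="The_report_is_related_to",
)
item = ds[0]
print(tokenizer.decode(item["input_ids"]))
print(item["mask_pos"])  # index of the [MASK] token in the sequence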
prompt_model_factory.py ADDED
@@ -0,0 +1,88 @@
+ from typing import Optional
+ from transformers import BertModel
+ from transformers.models.bert.modeling_bert import (
+     BertPreTrainedModel,
+     BertOnlyMLMHead,
+ )
+ import torch
+
+
+ class BertForPromptFinetuning(BertPreTrainedModel):
+     def __init__(self, config, use_multi_label_words: bool = False):
+         super().__init__(config)
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.cls = BertOnlyMLMHead(config)
+         # Initialize weights and apply final processing
+         self.init_weights()
+
+         self.label_word_ids = None
+         self.use_multi_label_words = use_multi_label_words
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         mask_pos: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+     ):
+         if mask_pos is None:
+             raise ValueError("`mask_pos` should be assigned!")
+         mask_pos = mask_pos.squeeze()
+
+         # Encode everything
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+         )
+
+         # Get <mask> token representation
+         sequence_output = outputs[0]
+         sequence_mask_output = sequence_output[
+             torch.arange(sequence_output.size(0)), mask_pos
+         ]
+
+         # Logits over vocabulary tokens
+         # prediction_mask_scores.shape: [batch_size, vocab_size]
+         prediction_mask_scores = self.cls(sequence_mask_output)
+
+         # Return logits for each label
+         logits = []
+         if self.use_multi_label_words:
+             for label_id in self.label_word_ids:
+                 one_label_logits = []
+                 # multiple ids in one label_id
+                 for word_id in label_id:
+                     one_label_word_logits = prediction_mask_scores[:, word_id]
+                     one_label_logits.append(one_label_word_logits.unsqueeze(-1))
+                 # one_label_logits: (bs, num_label_words)
+                 one_label_logits = torch.cat(one_label_logits, -1)
+                 # Get the max logits to choose the label word
+                 logits.append(torch.max(one_label_logits, dim=1, keepdim=True)[0])
+
+         else:
+             for label_id in range(len(self.label_word_ids)):
+                 logits.append(
+                     prediction_mask_scores[:, self.label_word_ids[label_id]].unsqueeze(
+                         -1
+                     )
+                 )
+
+         # logits.shape: [batch_size, num_classes]
+         logits = torch.sigmoid(torch.cat(logits, -1))
+
+         loss = None
+         if labels is not None:
+             loss_fct = torch.nn.BCELoss()
+             loss = loss_fct(logits, labels.float())
+
+         output = (logits, outputs.hidden_states) if output_hidden_states else (logits,)
+         # Append the attention tuple as a single element, mirroring hidden_states above.
+         output = (output + (outputs.attentions,)) if output_attentions else output
+
+         return ((loss,) + output) if loss is not None else output
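
The per-class scoring above can be illustrated without loading a checkpoint: with use_multi_label_words=True, each class keeps the maximum [MASK] logit over its label-word ids, and a sigmoid over the concatenated scores gives independent multi-label probabilities. A sketch on toy tensors (not part of the commit; the vocabulary ids are hypothetical):

# Sketch: the verbalizer reduction performed in BertForPromptFinetuning.forward.
import torch

prediction_mask_scores = torch.randn(2, 30522)                # (batch_size, vocab_size)
label_word_ids = [[2003, 2054], [1996], [2010, 2023, 2056]]   # hypothetical verbalizer ids

logits = []
for ids in label_word_ids:
    per_word = prediction_mask_scores[:, ids]                 # (batch_size, num_label_words)
    logits.append(per_word.max(dim=1, keepdim=True)[0])       # max over a class's label words
probs = torch.sigmoid(torch.cat(logits, dim=-1))              # (batch_size, num_classes)
print(probs.shape)  # torch.Size([2, 3])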
utils.py ADDED
@@ -0,0 +1,138 @@
+ from pathlib import Path
+ import torch
+ import os
+ import random
+ import argparse
+ import json
+ import pandas as pd
+ import numpy as np
+ from sklearn.metrics import precision_recall_fscore_support
+ from ast import literal_eval
+
+
+ def pred_by_threshold(
+     t: float,
+     y_true: np.ndarray,
+     similarities: np.ndarray,
+     classes: dict,
+ ):
+     preds = (similarities >= t) * 1
+     sk_results = precision_recall_fscore_support(
+         y_true,
+         preds,
+         # average="samples",  # For calculating sample-wise P and R scores.
+     )
+     outputs = {
+         "f1": np.average(sk_results[2]),
+         "P": np.average(sk_results[0]),
+         "R": np.average(sk_results[1]),
+     }
+     for label_name, idx in classes.items():
+         outputs[f"{label_name}_f1"] = sk_results[2][idx]
+     return outputs
+
+
+ def get_avg_length(dataset: torch.utils.data.Dataset):
+     all_lengths = 0
+     data_size = len(dataset)
+     for i in range(data_size):
+         all_lengths += len(dataset[i]["input_ids"])
+     return all_lengths / data_size
+
+
+ def load_csv_multi_label(filename: str, col_name: str = "labels") -> pd.DataFrame:
+     """Prevent Pandas from converting lists of int into lists of strings.
+
+     Args:
+         filename (str): path of a csv file
+         col_name (str, optional): column name of lists of int. Defaults to 'labels'.
+
+     Returns:
+         pd.DataFrame: a Pandas dataframe
+     """
+     return pd.read_csv(filename, converters={col_name: literal_eval})
+
+
+ def save_logged_results(filename: str, results: dict):
+     try:
+         old_df = pd.read_csv(filename)
+         df = pd.concat([old_df, pd.DataFrame(results)], ignore_index=True)
+     except FileNotFoundError:
+         df = pd.DataFrame(results)
+
+     df.to_csv(filename, index=None)
+
+
+ def set_seed(seed):
+     """
+     Args:
+         seed: an integer number to initialize a pseudorandom number generator
+     """
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         # torch.cuda.manual_seed_all(seed)  # if using more than one GPU
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+
+ def save_baseline_table(
+     y_preds: list,
+     baseline_name: str,
+     baseline_result_file: str = "results/baselines.pkl",
+     all_doc_idx: list = None,
+ ) -> None:
+     if Path(baseline_result_file).exists():
+         df = pd.read_pickle(baseline_result_file)
+     else:
+         assert all_doc_idx is not None
+         df = pd.DataFrame({"doc_idx": all_doc_idx})
+
+     df[baseline_name] = y_preds
+     df.to_pickle(baseline_result_file)
+
+
+ def load_params(path_of_params):
+     with open(path_of_params, "r") as f:
+         params = json.load(f)
+     return argparse.Namespace(**params)
+
+
+ def get_label_words(classes: list, use_multi_label_words=False) -> list:
+     mapping = {
+         "cyst": "cyst",
+         "HCC": "hcc",  # hepatoma
+         "cirrhosis": "cirrhosis",
+         "post-treatment": "posttreatment",
+         "steatosis": "steatosis",
+         "metastasis": "metastasis",
+         "hemangioma": "hemangioma",
+     }
+     if use_multi_label_words:
+         mapping = {
+             "cyst": ["cyst"],
+             "HCC": ["hcc", "hepatoma"],
+             "cirrhosis": ["cirrhosis"],
+             "post-treatment": ["posttreatment"],
+             "steatosis": ["steatosis", "steatohepatitis"],
+             "metastasis": ["metastasis"],
+             "hemangioma": ["hemangioma"],
+         }
+
+     label_words = [mapping[c] for c in classes]
+     return label_words
+
+
+ def seed_mapper(data_type: str) -> list:
+     mapping = {
+         "train_8": [2, 4, 7, 11, 21, 23, 24, 36, 44, 128],
+         "train_32": [0, 1, 3, 7, 10],
+     }
+     if data_type in mapping:
+         return mapping[data_type]
+     else:
+         raise NotImplementedError
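
pred_by_threshold drives both evaluation and the app's thresholding: scores at or above t count as positive, and the returned dict holds macro P/R/F1 plus one F1 entry per class. A toy sketch (not part of the commit); the class-to-index mapping here is assumed for illustration only.

# Sketch: thresholded multi-label metrics on toy data.
import numpy as np
from utils import pred_by_threshold

classes = {"cyst": 0, "HCC": 1, "cirrhosis": 2}          # hypothetical mapping
y_true = np.array([[1, 0, 1], [0, 1, 0]])
scores = np.array([[0.9, 0.1, 0.4], [0.05, 0.7, 0.3]])
metrics = pred_by_threshold(t=0.2, y_true=y_true, similarities=scores, classes=classes)
print(metrics["f1"], metrics["cyst_f1"])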