Spaces:
Runtime error
Runtime error
<ADD> optimize code
Browse files
app.py
CHANGED
|
@@ -56,17 +56,10 @@ def do_infer_probs(model, exemplar_attn_kv, exemplar_attn_mask, batched_choices_
|
|
| 56 |
|
| 57 |
@torch.no_grad()
|
| 58 |
def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
|
| 59 |
-
model_name, model_size = "opt", "125m"
|
| 60 |
-
step_size, momentum = 0.01, 0.9
|
| 61 |
-
|
| 62 |
setup_cpu(seed=seed)
|
| 63 |
TaskHandler = load_task(dataset_name)
|
| 64 |
task_agent = TaskHandler(prompt_version)
|
| 65 |
|
| 66 |
-
tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
|
| 67 |
-
model = build_model(model_name, model_size, False)
|
| 68 |
-
torch.autograd.set_grad_enabled(False)
|
| 69 |
-
|
| 70 |
processed_data = task_agent.dataset_preprocess(raw_data)
|
| 71 |
dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
|
| 72 |
|
|
@@ -132,6 +125,13 @@ if __name__ == "__main__":
|
|
| 132 |
prompt_version = "default"
|
| 133 |
kv_iter = 10
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
print(f"Dataset: {dataset_name}")
|
| 136 |
task_root = Path("example_sets").joinpath(dataset_name)
|
| 137 |
|
|
|
|
| 56 |
|
| 57 |
@torch.no_grad()
|
| 58 |
def process_once(dataset_name, exemplar_str, forward_steps, raw_data):
|
|
|
|
|
|
|
|
|
|
| 59 |
setup_cpu(seed=seed)
|
| 60 |
TaskHandler = load_task(dataset_name)
|
| 61 |
task_agent = TaskHandler(prompt_version)
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
processed_data = task_agent.dataset_preprocess(raw_data)
|
| 64 |
dataset = TokenizedForMCRightPad(processed_data, tokenizer, task_agent.multiple_choice_promptify)
|
| 65 |
|
|
|
|
| 125 |
prompt_version = "default"
|
| 126 |
kv_iter = 10
|
| 127 |
|
| 128 |
+
model_name, model_size = "opt", "125m"
|
| 129 |
+
step_size, momentum = 0.01, 0.9
|
| 130 |
+
setup_cpu(seed=seed)
|
| 131 |
+
tokenizer = build_tokenizer(model_name, model_size, padding_side="right")
|
| 132 |
+
model = build_model(model_name, model_size, False)
|
| 133 |
+
torch.autograd.set_grad_enabled(False)
|
| 134 |
+
|
| 135 |
print(f"Dataset: {dataset_name}")
|
| 136 |
task_root = Path("example_sets").joinpath(dataset_name)
|
| 137 |
|