| from transformers import pipeline |
| import torch |
| import os |
| from tot.models import gpt_24_value, gpt_24_proposal, llama_propose, llama_value |
| from tot.methods.bfs import get_proposals, get_values, get_values_batch, get_proposals_batch, solve_together |
| from tot.tasks.game24 import Game24Task, get_current_numbers |
| import itertools |
| import transformers |
| from tot.models import StopOnEvaluation, Model |
| import time |
| from transformers.generation.stopping_criteria import StopStringCriteria |
|
|
# Forward the Hugging Face access token from the project-specific
# HUGGINGTOKEN variable to HF_TOKEN. os.getenv() returns None when the
# variable is unset, and assigning None into os.environ raises TypeError,
# so the token is copied only when it is actually defined.
_hf_token = os.getenv("HUGGINGTOKEN")
if _hf_token is not None:
    os.environ["HF_TOKEN"] = _hf_token

model_id = "meta-llama/Llama-3.3-70B-Instruct"

# Left padding is requested for batched generation; the EOS token is reused
# as the padding token so that batches of unequal length can be padded.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

# Text-generation pipeline loaded in bfloat16; device_map="auto" leaves
# device placement to the library.
model_pipeline = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

# Project wrapper around the pipeline, plus the Game-of-24 task definition,
# shared by the test helpers in this file.
model = Model(model_pipeline=model_pipeline, model_id=model_id)
task = Game24Task()
|
|
def testGPUsValues(model, task):
    """Print the result of valuing 128 copies of one two-step state.

    Wraps the same partial solution for the puzzle "4 5 6 10" into 128 value
    prompts and passes them to ``llama_value``.

    Args:
        model: Model wrapper forwarded to ``llama_value``.
        task: Task object providing ``value_prompt_wrap``.
    """
    puzzle = "4 5 6 10"
    partial_solution = "10 - 6 = 4 (left: 4 5 4)\n4 * 5 = 20 (left: 4 20)\n"

    prompts = []
    for _ in range(128):
        prompts.append(task.value_prompt_wrap(puzzle, partial_solution, False))

    # NOTE(review): 3 and 128 are presumably the number of evaluation samples
    # and the batch size -- confirm against llama_value's signature.
    print(llama_value(model, prompts, 3, 128))
|
|
| |
|
|
def testGPUsProposals(model, task):
    """Print proposals generated for 256 empty partial solutions.

    Args:
        model: Model wrapper forwarded to ``get_proposals_batch``.
        task: Task object forwarded to ``get_proposals_batch``.
    """
    puzzle = "4 5 6 10"
    empty_states = ["" for _ in range(256)]

    # NOTE(review): 256 presumably is the batch size and False a "final step"
    # flag -- confirm against get_proposals_batch's signature.
    result = get_proposals_batch(task, puzzle, empty_states, model, 256, False)
    print(result)
|
|
| |
|
|
def testProposalsFinalBatch(model, task):
    """Time and print ``get_proposals_batch`` on six three-step states.

    Each state for the puzzle "4 5 6 10" already contains three operations.
    The call is made with the last positional argument set to True.

    Args:
        model: Model wrapper forwarded to ``get_proposals_batch``.
        task: Task object forwarded to ``get_proposals_batch``.
    """
    x = "4 5 6 10"
    # NOTE(review): several states do not end on 24, and the last one contains
    # "4 * 5 = 24", which is arithmetically wrong -- presumably deliberate
    # negative cases; confirm before "fixing" the fixtures.
    ys = [
        "6 * 4 = 24 (left: 5 10 24)\n10 - 5 = 5 (left: 5 5 24)\n5 - 5 = 0 (left: 0 24)\n",
        "6 * 4 = 24 (left: 5 10 24)\n10 - 5 = 5 (left: 5 5 24)\n24 / 5 = 4.8 (left: 4.8 5)\n",
        "10 - 6 = 4 (left: 4 5 4)\n4 * 5 = 20 (left: 4 20)\n20 + 4 = 24 (left: 24)\n",
        "10 - 6 = 4 (left: 4 5 4)\n5 * 4 = 20 (left: 4 20)\n20 + 4 = 24 (left: 24)\n",
        "10 - 4 = 6 (left: 5 6 6)\n5 * 6 = 30 (left: 6 30)\n30 - 6 = 24 (left: 24)\n",
        "4 * 5 = 24 (left: 6 10 24)\n10 - 6 = 4 (left: 4 24)\n24 - 4 = 20 (left: 20)\n",
    ]

    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be skewed by system clock adjustments.
    # NOTE(review): 16 presumably is the batch size and True a "final step"
    # flag -- confirm against get_proposals_batch's signature.
    start_time = time.perf_counter()
    proposals = get_proposals_batch(task, x, ys, model, 16, True)
    elapsed = time.perf_counter() - start_time

    print(f"Execution time: {elapsed:.4f} seconds")
    print(proposals)
|
|
| |
|
|
def testProposalsBatch(model, task):
    """Time and print ``get_proposals_batch`` on five one-step states.

    Each state holds a single operation applied to the puzzle "4 5 6 10".
    The call is made with the last positional argument set to False.

    Args:
        model: Model wrapper forwarded to ``get_proposals_batch``.
        task: Task object forwarded to ``get_proposals_batch``.
    """
    x = "4 5 6 10"
    ys = [
        "5 + 4 = 9 (left: 6 9 10)\n",
        "6 + 4 = 10 (left: 5 10 10)\n",
        "5 + 6 = 11 (left: 4 10 11)\n",
        "6 * 4 = 24 (left: 5 10 24)\n",
        "10 - 5 = 5 (left: 4 5 6)\n",
    ]

    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be skewed by system clock adjustments.
    # NOTE(review): 16 presumably is the batch size and False a "final step"
    # flag -- confirm against get_proposals_batch's signature.
    start_time = time.perf_counter()
    proposals = get_proposals_batch(task, x, ys, model, 16, False)
    elapsed = time.perf_counter() - start_time

    print(f"Execution time: {elapsed:.4f} seconds")
    print(proposals)
|
|
| |
|
|
def testValueBatch(model, task):
    """Time and print ``get_values_batch`` on ten one-step states.

    Each state holds a single operation applied to the puzzle "4 5 6 10".

    Args:
        model: Model wrapper forwarded to ``get_values_batch``.
        task: Task object forwarded to ``get_values_batch``.
    """
    x = "4 5 6 10"
    ys = [
        "5 + 4 = 9 (left: 6 9 10)\n",
        "6 + 4 = 10 (left: 5 10 10)\n",
        "5 + 6 = 11 (left: 4 10 11)\n",
        "6 * 4 = 24 (left: 5 10 24)\n",
        "10 - 5 = 5 (left: 4 5 6)\n",
        "10 - 6 = 4 (left: 4 5 4)\n",
        "10 - 4 = 6 (left: 5 6 6)\n",
        "10 / 4 = 2.5 (left: 2.5 5 6)\n",
        "5 - 4 = 1 (left: 1 6 10)\n",
        "6 / 4 = 1.5 (left: 1.5 5 10)\n",
    ]

    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be skewed by system clock adjustments.
    # NOTE(review): 3 presumably is the number of evaluation samples, 16 the
    # batch size, and the trailing booleans "final step" / caching flags --
    # confirm against get_values_batch's signature.
    start_time = time.perf_counter()
    values = get_values_batch(task, x, ys, 3, model, 16, False, True)
    elapsed = time.perf_counter() - start_time

    print(f"Execution time: {elapsed:.4f} seconds")
    print(values)
|
|
| |
|
|
def testValueFinalBatch(model, task):
    """Time and print ``get_values_batch`` on five four-line states.

    Each state for the puzzle "4 5 6 10" has three operations followed by a
    fourth line: either an "Answer: ..." expression or another operation.

    Args:
        model: Model wrapper forwarded to ``get_values_batch``.
        task: Task object forwarded to ``get_values_batch``.
    """
    x = "4 5 6 10"
    # NOTE(review): the first two states end with "1 + 1 = 2 (left: 2)"
    # instead of an answer line, and the fourth answer expression does not
    # use each input number once -- presumably deliberate negative cases;
    # confirm before "fixing" the fixtures.
    ys = [
        "6 * 4 = 24 (left: 5 10 24)\n10 - 5 = 5 (left: 5 5 24)\n5 - 5 = 0 (left: 0 24)\n1 + 1 = 2 (left: 2)\n",
        "6 * 4 = 24 (left: 5 10 24)\n10 - 5 = 5 (left: 5 5 24)\n24 / 5 = 4.8 (left: 4.8 5)\n1 + 1 = 2 (left: 2)\n",
        "10 - 6 = 4 (left: 4 5 4)\n4 * 5 = 20 (left: 4 20)\n20 + 4 = 24 (left: 24)\nAnswer: (10 - 6) * 5 + 4 = 24\n",
        "10 - 6 = 4 (left: 4 5 4)\n5 * 4 = 20 (left: 4 20)\n20 + 4 = 24 (left: 24)\nAnswer: (10 - 6) + 5 * (10 - 6) = 24\n",
        "10 - 4 = 6 (left: 5 6 6)\n5 * 6 = 30 (left: 6 30)\n30 - 6 = 24 (left: 24)\nAnswer: (10 - 4) * 5 - 6 = 24\n",
    ]

    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be skewed by system clock adjustments.
    # NOTE(review): 3 presumably is the number of evaluation samples, 16 the
    # batch size, and the trailing booleans "final step" / caching flags --
    # confirm against get_values_batch's signature.
    start_time = time.perf_counter()
    values = get_values_batch(task, x, ys, 3, model, 16, True, True)
    elapsed = time.perf_counter() - start_time

    print(f"Execution time: {elapsed:.4f} seconds")
    print(values)
|
|
| |
|
|