File size: 2,123 Bytes
66f1733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from transformers import AutoModelForCausalLM
import re


from evolutiontransformer.worker import (
    load_base_models_if_needed,
    BASE_MODELS,
    inference,
    inference_task,
    merge_models,
)


def get_final_answer(text: str) -> int | None:
    numbers = re.findall(r"\d+", text)
    return int(numbers[-1]) if numbers else None


def test_inference():
    """Smoke-test single-model inference against the SVAMP base model."""
    session_id = "test_session"

    print("### Testing inference on SVAMP model...")
    prompt = (
        "If there are 3 cars and 2 bikes, how many vehicles are there in total?"
        "\nAnswer:"
    )
    result = inference_task(session_id, "svamp", prompt)
    # The model should arrive at 3 + 2 = 5 as the final number in its reply.
    assert get_final_answer(result["response"]) == 5


def test_merge_models():
    """An identity recipe (every SVAMP layer at weight 1.0) must reproduce
    the SVAMP base model parameter-for-parameter."""
    load_base_models_if_needed()

    # Identity recipe: each of the 24 layers is taken solely from "svamp",
    # and the embedding/linear interpolation weights are all 1.0.
    model_recipe = {
        "layer_recipe": [[(i, "svamp", 1.0)] for i in range(24)],
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }

    merged_model = merge_models(model_recipe)

    # strict=True guards against the two models silently having different
    # parameter counts — a plain zip() would truncate and the loop could
    # pass vacuously over a partial comparison.
    for (name1, param1), (name2, param2) in zip(
        BASE_MODELS["svamp"].named_parameters(),
        merged_model.named_parameters(),
        strict=True,
    ):
        assert torch.allclose(param1, param2), f"mismatch at {name1} / {name2}"


def test_merge_models_with_inference1():
    """Build a 48-layer merge (each SVAMP layer used twice, the second pass
    at half weight) and print a sample generation from the result."""
    load_base_models_if_needed()

    # Entries 0-23: svamp layers 0-23 at full weight.
    # Entries 24-47: the same layers again, scaled to 0.5.
    layer_recipe = []
    for idx in range(48):
        weight = 1.0 if idx < 24 else 0.5
        layer_recipe.append([(idx % 24, "svamp", weight)])

    recipe = {
        "layer_recipe": layer_recipe,
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
    }

    model = merge_models(recipe)

    prompt = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    print(inference(model, prompt))


def test_merge_models_with_inference2():
    """Merge using tinystories layers but zeroed embedding/linear lambdas,
    then print a sample generation to inspect the degenerate output."""
    load_base_models_if_needed()

    # All 24 layers come from "tinystories"; the 0.0 lambdas zero out the
    # embedding and linear interpolation contributions.
    recipe = {
        "layer_recipe": [[(layer_idx, "tinystories", 1.0)] for layer_idx in range(24)],
        "embedding_lambdas": [0.0, 0.0],
        "linear_lambdas": [0.0, 0.0],
    }

    model = merge_models(recipe)

    prompt = "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:"
    print(inference(model, prompt))