File size: 1,986 Bytes
9ed01de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from transformers import AutoConfig


def extend_list(data_list, n, min_n):
    """Pad ``data_list`` in place to at least ``n`` items by repeating its head.

    The list is grown by appending copies of its own leading elements until
    it holds ``n`` items. A list that already has ``n`` or more items is
    returned unchanged (it is never truncated).

    Args:
        data_list: List to pad. It is mutated in place and also returned.
        n: Minimum number of items wanted in the result.
        min_n: Only compared against 0: when it is 0 a new empty list is
            returned and ``data_list`` is left untouched.

    Returns:
        A new empty list when ``min_n == 0``; otherwise ``data_list`` itself,
        holding at least ``n`` items.

    Raises:
        ValueError: If ``data_list`` is empty while ``n`` is positive; there
            is nothing to repeat, and looping would never terminate.
    """
    if min_n == 0:
        return []
    # Extending with a slice of an empty list adds nothing, so the loop below
    # would spin forever; fail loudly instead.
    if n > 0 and len(data_list) == 0:
        raise ValueError("cannot extend an empty list to a positive length")
    while len(data_list) < n:
        # Copy at most the number of items still missing; the slice is
        # evaluated before the list grows, so each pass can at most double it.
        data_list.extend(data_list[: n - len(data_list)])
    return data_list


def find_prefix(input_ids, prefix):
    """Locate the first occurrence of ``prefix`` in every row of ``input_ids``.

    input_ids: [B, N1], no start token
    prefix: [N2, ], no start token

    Returns an int64 tensor of shape [B] holding, for each row, the start
    index of the first window equal to ``prefix``. Fails an assertion when
    any row does not contain ``prefix``.
    """
    window_size = prefix.shape[0]  # N2

    # Every contiguous window of length N2 along the sequence dimension.
    windows = input_ids.unfold(1, window_size, 1)

    # A window is a hit only if each of its positions equals the prefix.
    hits = (windows == prefix).all(dim=2)

    row_has_hit = hits.any(dim=1)
    # argmax over the 0/1 mask gives the first hit; it also yields 0 for rows
    # with no hit at all, so those rows are overwritten with -1 below.
    first_hit = hits.type(torch.int64).argmax(dim=1)
    not_found = torch.tensor(-1, dtype=torch.int64)
    indices = torch.where(row_has_hit, first_hit, not_found)

    assert (indices >= 0).all(), "Some inputs do not contain prefix"
    return indices


def auto_upgrade(config):
    """Interactively rewrite an old-format checkpoint config in place.

    Loads the configuration for ``config`` with ``AutoConfig.from_pretrained``.
    Nothing further happens unless the string ``config`` contains
    "mplug_owl2" while the loaded ``model_type`` does not. In that case the
    user is prompted on stdin; on "y"/"yes" the config is patched and saved
    back to ``config``, otherwise the process exits with status 1.

    Args:
        config: Checkpoint name or path as a string. It is passed to
            ``AutoConfig.from_pretrained`` and, after an upgrade, to
            ``save_pretrained``.

    Raises:
        AssertionError: If the legacy ``model_type`` check fails, or if the
            config does not list exactly one architecture.
        SystemExit: With code 1 when the user declines the upgrade.
    """
    cfg = AutoConfig.from_pretrained(config)
    if "mplug_owl2" not in config or "mplug_owl2" in cfg.model_type:
        return

    # NOTE(review): this branch is only reached when "mplug_owl2" is NOT in
    # cfg.model_type, so the check below can never hold and, with assertions
    # enabled, the upgrade path always stops here. It looks like a leftover
    # from adapting another code base; the messages and the architecture name
    # written below still refer to LLaVA as well. Confirm the intended legacy
    # model_type before relaxing this check.
    assert cfg.model_type == "mplug_owl2", (
        f"Unexpected model_type {cfg.model_type!r} for checkpoint {config!r}"
    )
    print(
        "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
    )
    print(
        "You must upgrade the checkpoint to the new code base (this can be done automatically)."
    )
    confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
    if confirm.lower() not in ["y", "yes"]:
        print("Checkpoint upgrade aborted.")
        # Same effect as exit(1), without relying on the helper that the
        # `site` module only installs for interactive use.
        raise SystemExit(1)

    print("Upgrading checkpoint...")
    assert len(cfg.architectures) == 1
    # model_type is set on the config *class*, so the change is visible to
    # every instance of that class in this process.
    cfg.__class__.model_type = "mplug_owl2"
    cfg.architectures[0] = "LlavaLlamaForCausalLM"
    cfg.save_pretrained(config)
    print("Checkpoint upgraded.")