import torch
from transformers import AutoConfig
def extend_list(data_list, n, min_n):
    """Grow data_list in place, by cycling its elements, until it has n items.

    Returns an empty list when min_n == 0; otherwise returns data_list itself.
    Assumes data_list is non-empty in that case, or the loop never terminates.
    """
    if min_n == 0:
        return []
    while len(data_list) < n:
        # Append a copy of the leading elements until the target length is hit
        data_list.extend(data_list[:n - len(data_list)])
    return data_list
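

# Usage sketch (illustrative, not part of the original module): cycling a
# 3-element list up to length 8.
#   >>> extend_list([1, 2, 3], n=8, min_n=1)
#   [1, 2, 3, 1, 2, 3, 1, 2]
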
def find_prefix(input_ids, prefix):
    """Find where `prefix` first occurs in each row of `input_ids`.

    input_ids: LongTensor of shape [B, N1], without the start token.
    prefix:    LongTensor of shape [N2],    without the start token.

    Returns a LongTensor of shape [B] with the start index of the first
    occurrence per row; raises AssertionError if any row has no occurrence.
    """
    len_prefix = prefix.shape[0]  # N2
    # Create all sliding windows of length len_prefix: [B, N1 - N2 + 1, N2]
    input_ids_unfold = input_ids.unfold(1, len_prefix, 1)
    # A window matches when every element equals the corresponding prefix token
    matches = (input_ids_unfold == prefix).all(dim=2)
    # Cast to integers so argmax returns the index of the first True per row
    matches_int = matches.type(torch.int64)
    # Start index of the first match per row, or -1 if the row has no match
    indices = torch.where(
        matches.any(dim=1),
        matches_int.argmax(dim=1),
        # Keep the fill value on the same device as the inputs
        torch.tensor(-1, dtype=torch.int64, device=input_ids.device),
    )
    assert (indices >= 0).all(), "Some inputs do not contain prefix"
    return indices
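

# Usage sketch (illustrative): locate the token pair [5, 6] in a small batch.
#   >>> ids = torch.tensor([[1, 5, 6, 2], [5, 6, 3, 4]])
#   >>> find_prefix(ids, torch.tensor([5, 6]))
#   tensor([1, 0])
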
def auto_upgrade(config):
    """Interactively upgrade a v0 checkpoint's config to the new code base.

    `config` is the checkpoint path (or repo id) passed to
    AutoConfig.from_pretrained; the upgrade rewrites the config in place.
    """
    cfg = AutoConfig.from_pretrained(config)
    if "mplug_owl2" in config and "mplug_owl2" not in cfg.model_type:
        # v0 checkpoints are assumed to still report the base "llama" model
        # type, as in the LLaVA code base this function is adapted from.
        assert cfg.model_type == "llama"
        print(
            "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
        )
        print(
            "You must upgrade the checkpoint to the new code base (this can be done automatically)."
        )
        confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            setattr(cfg.__class__, "model_type", "mplug_owl2")
            cfg.architectures[0] = "LlavaLlamaForCausalLM"
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
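

# Usage sketch (the path below is hypothetical):
#   auto_upgrade("./checkpoints/mplug_owl2-v0")
# prompts for confirmation, then rewrites the checkpoint's config.json.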